diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp index d46b7118c7905fd5ae2a9d413eaff51a97c7ed51..fe1302270a08d66349d6ce0ba4ed2ca6e0bd4420 100644 --- a/unit_tests/operator/Test_ConcatImpl.cpp +++ b/unit_tests/operator/Test_ConcatImpl.cpp @@ -19,6 +19,31 @@ using namespace Aidge; TEST_CASE("[cpu/operator] Concat(forward)", "[Concat]") { + SECTION("Concat 1D inputs") { + std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array1D<int,2>{{ 2, 3 }}); + std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array1D<int,3>{{ 4, 5, 6 }}); + std::shared_ptr<Tensor> input3 = std::make_shared<Tensor>(Array1D<int,4>{{ 7, 8, 9, 10 }}); + std::shared_ptr<Tensor> input4 = std::make_shared<Tensor>(Array1D<int,5>{{ 11, 12, 13, 14, 15 }}); + std::shared_ptr<Tensor> input5 = std::make_shared<Tensor>(Array1D<int,6>{{ 16, 17, 18, 19, 20, 21 }}); + + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,20>{ + { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }}); + + auto myConcat = Concat(5, 0); + myConcat->getOperator()->setBackend("cpu"); + myConcat->getOperator()->setDatatype(DataType::Int32); + myConcat->getOperator()->associateInput(0, input1); + myConcat->getOperator()->associateInput(1, input2); + myConcat->getOperator()->associateInput(2, input3); + myConcat->getOperator()->associateInput(3, input4); + myConcat->getOperator()->associateInput(4, input5); + myConcat->getOperator()->computeOutputDims(); + myConcat->forward(); + + myConcat->getOperator()->getOutput(0)->print(); + + REQUIRE(myConcat->getOperator()->output(0) == *expectedOutput); + } SECTION("Concat 4D inputs on 1st axis") { std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> { { //