diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp index fcdf3e8cc1bc07493cfa84608f200f9f334a29cc..677f78e54f850001ab648bfd03c3415b212ba3f2 100644 --- a/unit_tests/operator/Test_ConcatImpl.cpp +++ b/unit_tests/operator/Test_ConcatImpl.cpp @@ -33,22 +33,27 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { std::shared_ptr<Tensor> input4 = std::make_shared<Tensor>(Array1D<int,5>{{ 11, 12, 13, 14, 15 }}); std::shared_ptr<Tensor> input5 = std::make_shared<Tensor>(Array1D<int,6>{{ 16, 17, 18, 19, 20, 21 }}); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,20>{ - { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }}); - - auto myConcat = Concat(5, 0); - myConcat->getOperator()->associateInput(0, input1); - myConcat->getOperator()->associateInput(1, input2); - myConcat->getOperator()->associateInput(2, input3); - myConcat->getOperator()->associateInput(3, input4); - myConcat->getOperator()->associateInput(4, input5); - myConcat->getOperator()->setBackend("cpu"); - myConcat->getOperator()->setDataType(DataType::Int32); - myConcat->forward(); - - std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print(); - - REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); + Tensor expectedOutput = Array1D<int,20>{ + { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }}; + + std::shared_ptr<Concat_Op> op = std::make_shared<Concat_Op>(5,0); + op->associateInput(0, input1); + op->associateInput(1, input2); + op->associateInput(2, input3); + op->associateInput(3, input4); + op->associateInput(4, input5); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); + fmt::print("{}\n", *(op->getInput(0))); + fmt::print("{}\n", *(op->getInput(1))); + fmt::print("{}\n", *(op->getInput(2))); + fmt::print("{}\n", *(op->getInput(3))); + fmt::print("{}\n", *(op->getInput(4))); + op->forward(); + + fmt::print("res: {}\n", *(op->getOutput(0))); + + REQUIRE(*(op->getOutput(0)) == expectedOutput); } SECTION("Concat 4D inputs on 1st axis") { std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> { @@ -75,7 +80,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { } // }); // - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> { + Tensor expectedOutput = Array4D<int,3,3,3,2> { { // { // {{20, 47},{21, 48},{22, 49}}, // @@ -93,18 +98,19 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { {{44, 71},{45, 72},{46, 73}} // } // } // - }); // + }; // auto myConcat = Concat(2, 0); - myConcat->getOperator()->associateInput(0, input1); - myConcat->getOperator()->associateInput(1, input2); - myConcat->getOperator()->setBackend("cpu"); - myConcat->getOperator()->setDataType(DataType::Int32); + std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator()); + op->associateInput(0, input1); + op->associateInput(1, input2); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); myConcat->forward(); - std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0)->print(); + fmt::print("res: {}\n", *(op->getOutput(0))); - REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); + REQUIRE(*(op->getOutput(0)) == expectedOutput); } SECTION("Concat 4D inputs on 3rd axis") { @@ -127,7 +133,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { } }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,9,2> { + Tensor expectedOutput = Array4D<int,1,3,9,2> { { // { // {{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}}, // @@ -135,17 +141,18 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { {{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}} // }, // } // - }); // + }; // auto myConcat = Concat(2, 2); - myConcat->getOperator()->associateInput(0, input1); - myConcat->getOperator()->associateInput(1, input2); - myConcat->getOperator()->setBackend("cpu"); - myConcat->getOperator()->setDataType(DataType::Int32); + std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator()); + op->associateInput(0, input1); + op->associateInput(1, input2); + op->setBackend("cpu"); + op->setDataType(DataType::Int32); myConcat->forward(); std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print(); - REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); + REQUIRE(*(op->getOutput(0)) == expectedOutput); } }