Skip to content
Snippets Groups Projects
Commit c0b186ad authored by Maxence Naud's avatar Maxence Naud
Browse files

update 'Test_ConcatImpl.cpp'

parent 54621d46
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!307[UPD] Tensor formatting
Pipeline #63323 passed
...@@ -33,22 +33,27 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { ...@@ -33,22 +33,27 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
std::shared_ptr<Tensor> input4 = std::make_shared<Tensor>(Array1D<int,5>{{ 11, 12, 13, 14, 15 }}); std::shared_ptr<Tensor> input4 = std::make_shared<Tensor>(Array1D<int,5>{{ 11, 12, 13, 14, 15 }});
std::shared_ptr<Tensor> input5 = std::make_shared<Tensor>(Array1D<int,6>{{ 16, 17, 18, 19, 20, 21 }}); std::shared_ptr<Tensor> input5 = std::make_shared<Tensor>(Array1D<int,6>{{ 16, 17, 18, 19, 20, 21 }});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,20>{ Tensor expectedOutput = Array1D<int,20>{
{ 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }}); { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }};
auto myConcat = Concat(5, 0); std::shared_ptr<Concat_Op> op = std::make_shared<Concat_Op>(5,0);
myConcat->getOperator()->associateInput(0, input1); op->associateInput(0, input1);
myConcat->getOperator()->associateInput(1, input2); op->associateInput(1, input2);
myConcat->getOperator()->associateInput(2, input3); op->associateInput(2, input3);
myConcat->getOperator()->associateInput(3, input4); op->associateInput(3, input4);
myConcat->getOperator()->associateInput(4, input5); op->associateInput(4, input5);
myConcat->getOperator()->setBackend("cpu"); op->setBackend("cpu");
myConcat->getOperator()->setDataType(DataType::Int32); op->setDataType(DataType::Int32);
myConcat->forward(); fmt::print("{}\n", *(op->getInput(0)));
fmt::print("{}\n", *(op->getInput(1)));
std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print(); fmt::print("{}\n", *(op->getInput(2)));
fmt::print("{}\n", *(op->getInput(3)));
REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); fmt::print("{}\n", *(op->getInput(4)));
op->forward();
fmt::print("res: {}\n", *(op->getOutput(0)));
REQUIRE(*(op->getOutput(0)) == expectedOutput);
} }
SECTION("Concat 4D inputs on 1st axis") { SECTION("Concat 4D inputs on 1st axis") {
std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> { std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
...@@ -75,7 +80,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { ...@@ -75,7 +80,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
} // } //
}); // }); //
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> { Tensor expectedOutput = Array4D<int,3,3,3,2> {
{ // { //
{ // { //
{{20, 47},{21, 48},{22, 49}}, // {{20, 47},{21, 48},{22, 49}}, //
...@@ -93,18 +98,19 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { ...@@ -93,18 +98,19 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
{{44, 71},{45, 72},{46, 73}} // {{44, 71},{45, 72},{46, 73}} //
} // } //
} // } //
}); // }; //
auto myConcat = Concat(2, 0); auto myConcat = Concat(2, 0);
myConcat->getOperator()->associateInput(0, input1); std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator());
myConcat->getOperator()->associateInput(1, input2); op->associateInput(0, input1);
myConcat->getOperator()->setBackend("cpu"); op->associateInput(1, input2);
myConcat->getOperator()->setDataType(DataType::Int32); op->setBackend("cpu");
op->setDataType(DataType::Int32);
myConcat->forward(); myConcat->forward();
std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0)->print(); fmt::print("res: {}\n", *(op->getOutput(0)));
REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); REQUIRE(*(op->getOutput(0)) == expectedOutput);
} }
SECTION("Concat 4D inputs on 3rd axis") { SECTION("Concat 4D inputs on 3rd axis") {
...@@ -127,7 +133,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { ...@@ -127,7 +133,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
} }
}); });
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,9,2> { Tensor expectedOutput = Array4D<int,1,3,9,2> {
{ // { //
{ // { //
{{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}}, // {{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}}, //
...@@ -135,17 +141,18 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") { ...@@ -135,17 +141,18 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
{{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}} // {{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}} //
}, // }, //
} // } //
}); // }; //
auto myConcat = Concat(2, 2); auto myConcat = Concat(2, 2);
myConcat->getOperator()->associateInput(0, input1); std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator());
myConcat->getOperator()->associateInput(1, input2); op->associateInput(0, input1);
myConcat->getOperator()->setBackend("cpu"); op->associateInput(1, input2);
myConcat->getOperator()->setDataType(DataType::Int32); op->setBackend("cpu");
op->setDataType(DataType::Int32);
myConcat->forward(); myConcat->forward();
std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print(); std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print();
REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput); REQUIRE(*(op->getOutput(0)) == expectedOutput);
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment