Skip to content
Snippets Groups Projects

[UPD] Tensor formatting

Merged Maxence Naud requested to merge feat_precision-log into dev
1 file
+ 38
31
Compare changes
  • Side-by-side
  • Inline
@@ -33,22 +33,27 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
std::shared_ptr<Tensor> input4 = std::make_shared<Tensor>(Array1D<int,5>{{ 11, 12, 13, 14, 15 }});
std::shared_ptr<Tensor> input5 = std::make_shared<Tensor>(Array1D<int,6>{{ 16, 17, 18, 19, 20, 21 }});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,20>{
{ 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }});
auto myConcat = Concat(5, 0);
myConcat->getOperator()->associateInput(0, input1);
myConcat->getOperator()->associateInput(1, input2);
myConcat->getOperator()->associateInput(2, input3);
myConcat->getOperator()->associateInput(3, input4);
myConcat->getOperator()->associateInput(4, input5);
myConcat->getOperator()->setBackend("cpu");
myConcat->getOperator()->setDataType(DataType::Int32);
myConcat->forward();
std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print();
REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
Tensor expectedOutput = Array1D<int,20>{
{ 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }};
std::shared_ptr<Concat_Op> op = std::make_shared<Concat_Op>(5,0);
op->associateInput(0, input1);
op->associateInput(1, input2);
op->associateInput(2, input3);
op->associateInput(3, input4);
op->associateInput(4, input5);
op->setBackend("cpu");
op->setDataType(DataType::Int32);
fmt::print("{}\n", *(op->getInput(0)));
fmt::print("{}\n", *(op->getInput(1)));
fmt::print("{}\n", *(op->getInput(2)));
fmt::print("{}\n", *(op->getInput(3)));
fmt::print("{}\n", *(op->getInput(4)));
op->forward();
fmt::print("res: {}\n", *(op->getOutput(0)));
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Concat 4D inputs on 1st axis") {
std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
@@ -75,7 +80,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
} //
}); //
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
Tensor expectedOutput = Array4D<int,3,3,3,2> {
{ //
{ //
{{20, 47},{21, 48},{22, 49}}, //
@@ -93,18 +98,19 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
{{44, 71},{45, 72},{46, 73}} //
} //
} //
}); //
}; //
auto myConcat = Concat(2, 0);
myConcat->getOperator()->associateInput(0, input1);
myConcat->getOperator()->associateInput(1, input2);
myConcat->getOperator()->setBackend("cpu");
myConcat->getOperator()->setDataType(DataType::Int32);
std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator());
op->associateInput(0, input1);
op->associateInput(1, input2);
op->setBackend("cpu");
op->setDataType(DataType::Int32);
myConcat->forward();
std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0)->print();
fmt::print("res: {}\n", *(op->getOutput(0)));
REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Concat 4D inputs on 3rd axis") {
@@ -127,7 +133,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,9,2> {
Tensor expectedOutput = Array4D<int,1,3,9,2> {
{ //
{ //
{{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}}, //
@@ -135,17 +141,18 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
{{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}} //
}, //
} //
}); //
}; //
auto myConcat = Concat(2, 2);
myConcat->getOperator()->associateInput(0, input1);
myConcat->getOperator()->associateInput(1, input2);
myConcat->getOperator()->setBackend("cpu");
myConcat->getOperator()->setDataType(DataType::Int32);
std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator());
op->associateInput(0, input1);
op->associateInput(1, input2);
op->setBackend("cpu");
op->setDataType(DataType::Int32);
myConcat->forward();
std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print();
REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
}
Loading