diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp
index fcdf3e8cc1bc07493cfa84608f200f9f334a29cc..677f78e54f850001ab648bfd03c3415b212ba3f2 100644
--- a/unit_tests/operator/Test_ConcatImpl.cpp
+++ b/unit_tests/operator/Test_ConcatImpl.cpp
@@ -33,22 +33,27 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
         std::shared_ptr<Tensor> input4 = std::make_shared<Tensor>(Array1D<int,5>{{ 11, 12, 13, 14, 15 }});
         std::shared_ptr<Tensor> input5 = std::make_shared<Tensor>(Array1D<int,6>{{ 16, 17, 18, 19, 20, 21 }});
 
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,20>{
-            { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }});
-
-        auto myConcat = Concat(5, 0);
-        myConcat->getOperator()->associateInput(0, input1);
-        myConcat->getOperator()->associateInput(1, input2);
-        myConcat->getOperator()->associateInput(2, input3);
-        myConcat->getOperator()->associateInput(3, input4);
-        myConcat->getOperator()->associateInput(4, input5);
-        myConcat->getOperator()->setBackend("cpu");
-        myConcat->getOperator()->setDataType(DataType::Int32);
-        myConcat->forward();
-
-        std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print();
-
-        REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
+        Tensor expectedOutput = Array1D<int,20>{
+            { 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,15,16,17,18,19,20,21 }};
+
+        std::shared_ptr<Concat_Op> op = std::make_shared<Concat_Op>(5,0);
+        op->associateInput(0, input1);
+        op->associateInput(1, input2);
+        op->associateInput(2, input3);
+        op->associateInput(3, input4);
+        op->associateInput(4, input5);
+        op->setBackend("cpu");
+        op->setDataType(DataType::Int32);
+        fmt::print("{}\n", *(op->getInput(0)));
+        fmt::print("{}\n", *(op->getInput(1)));
+        fmt::print("{}\n", *(op->getInput(2)));
+        fmt::print("{}\n", *(op->getInput(3)));
+        fmt::print("{}\n", *(op->getInput(4)));
+        op->forward();
+
+        fmt::print("res: {}\n", *(op->getOutput(0)));
+
+        REQUIRE(*(op->getOutput(0)) == expectedOutput);
     }
     SECTION("Concat 4D inputs on 1st axis") {
         std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
@@ -75,7 +80,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
             }                                       //
         });                                         //
 
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
+        Tensor expectedOutput = Array4D<int,3,3,3,2> {
             {                                       //
                 {                                   //
                     {{20, 47},{21, 48},{22, 49}},   //
@@ -93,18 +98,19 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
                     {{44, 71},{45, 72},{46, 73}}    //
                 }                                   //
             }                                       //
-        });                                         //
+        };                                         //
 
         auto myConcat = Concat(2, 0);
-        myConcat->getOperator()->associateInput(0, input1);
-        myConcat->getOperator()->associateInput(1, input2);
-        myConcat->getOperator()->setBackend("cpu");
-        myConcat->getOperator()->setDataType(DataType::Int32);
+        std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator());
+        op->associateInput(0, input1);
+        op->associateInput(1, input2);
+        op->setBackend("cpu");
+        op->setDataType(DataType::Int32);
         myConcat->forward();
 
-        std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0)->print();
+        fmt::print("res: {}\n", *(op->getOutput(0)));
 
-        REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
+        REQUIRE(*(op->getOutput(0)) == expectedOutput);
     }
 
     SECTION("Concat 4D inputs on 3rd axis") {
@@ -127,7 +133,7 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
             }
         });
 
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,9,2> {
+        Tensor expectedOutput = Array4D<int,1,3,9,2> {
             {                                                                                             //
                 {                                                                                         //
                     {{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}},   //
@@ -135,17 +141,18 @@ TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
                     {{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}}    //
                 },                                                                                        //
             }                                                                                             //
-        });                                                                                               //
+        };                                                                                               //
 
         auto myConcat = Concat(2, 2);
-        myConcat->getOperator()->associateInput(0, input1);
-        myConcat->getOperator()->associateInput(1, input2);
-        myConcat->getOperator()->setBackend("cpu");
-        myConcat->getOperator()->setDataType(DataType::Int32);
+        std::shared_ptr<Concat_Op> op = std::static_pointer_cast<Concat_Op>(myConcat->getOperator());
+        op->associateInput(0, input1);
+        op->associateInput(1, input2);
+        op->setBackend("cpu");
+        op->setDataType(DataType::Int32);
         myConcat->forward();
 
         std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print();
 
-        REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
+        REQUIRE(*(op->getOutput(0)) == expectedOutput);
     }
 }