From 25e14ea1c2746dd7433ac945682b6ab265fb0a2c Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Thu, 30 Nov 2023 15:39:59 +0100
Subject: [PATCH] keep merged concat operator

---
 .../aidge/backend/cpu/operator/ConcatImpl.hpp |   3 +-
 .../operator/ConcatImpl_forward_kernels.hpp   |  61 ++++---
 src/operator/ConcatImpl.cpp                   |   4 +-
 unit_tests/operator/Test_ConcatImpl.cpp       | 163 ++++++++++--------
 4 files changed, 126 insertions(+), 105 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/ConcatImpl.hpp b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
index 43ec45f1..a5e0c56e 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl.hpp
@@ -20,7 +20,8 @@
 #include <vector>
 
 namespace Aidge {
-// class Concat_Op;
+// class Concat_Op;
+
 // compute kernel registry for forward and backward
 class ConcatImplForward_cpu
     : public Registrable<ConcatImplForward_cpu, std::tuple<DataType, DataType>, void(const Concat_Op::Attrs&,
diff --git a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
index 7651e055..b76f384b 100644
--- a/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp
@@ -12,49 +12,56 @@
 #ifndef AIDGE_CPU_OPERATOR_CONCATIMPL_FORWARD_KERNEL_H_
 #define AIDGE_CPU_OPERATOR_CONCATIMPL_FORWARD_KERNEL_H_
 
-#include "aidge/utils/Registrar.hpp"
+#include <algorithm>
+#include <numeric>
 #include <cstddef>
-#include <cmath>
-#include "aidge/data/Data.hpp"
-#include "aidge/utils/Types.h"
+#include <vector>
 
 #include "aidge/backend/cpu/operator/ConcatImpl.hpp"
+#include "aidge/data/Data.hpp"
+#include "aidge/operator/Concat.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
 
 namespace Aidge {
+/// Concat forward kernel: copies each buffer of `inputs_` into `output_`, joined along axis std::get<0>(attrs).
 template <class I, class O>
 void ConcatImpl_cpu_forward_kernel(const Concat_Op::Attrs& attrs,
-                                   const std::vector<DimSize_t>& inputDims,
+                                   const std::vector<DimSize_t>& dimsFirstInput,
                                    const std::vector<DimSize_t>& concatAxisValues,
                                    const std::vector<const void*>& inputs_,
                                    void* output_)
 {
-    std::size_t axisIdx = std::get<0>(attrs);
-    O* output = static_cast<O*>(output_);
-    std::vector<const I*> input;
-    for(const auto& elem:inputs_)
-    {
-        input.emplace_back(static_cast<const I*>(elem));
+    // concatAxisValues[k] is used as input k's extent on the concat axis; all other extents are read from dimsFirstInput.
+    std::vector<const I*> inputs;  // type-erased input pointers viewed as const I*
+    for (const auto& input_ : inputs_) {
+        inputs.push_back(static_cast<const I*>(input_));
     }
+    O* output = static_cast<O*>(output_);
 
-    std::size_t postAxisElems = 1;
-    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
-        postAxisElems *= inputDims[i];
+    const DimSize_t outputAxisValue = std::accumulate(concatAxisValues.begin(), concatAxisValues.end(), DimSize_t{0});  // DimSize_t init: the sum is not narrowed to int
+    // Products of the extents before (lower) and after (higher) the concat axis.
+    DimSize_t prodDimLower = 1;
+    for (DimIdx_t i = 0; i < std::get<0>(attrs); ++i) {
+        prodDimLower *= dimsFirstInput[i];
     }
-    std::size_t preAxisElems = 1;
-    for (std::size_t i = 0; i < axisIdx; ++i) {
-        preAxisElems *= inputDims[i];
+    DimSize_t prodDimHigher = 1;
+    for (DimIdx_t i = std::get<0>(attrs) + 1; static_cast<std::size_t>(i) < dimsFirstInput.size();
+         ++i) {
+        prodDimHigher *= dimsFirstInput[i];
     }
 
-    for(std::size_t i=0; i<preAxisElems ; ++i)
-    {
-		for(std::size_t j=0; j < input.size(); ++j)
-		{
-            std::size_t strideOnAxis = postAxisElems * concatAxisValues[j];
-            const I* copyPtr = std::next(input[j], i * strideOnAxis);
-            std::copy_n(copyPtr, strideOnAxis, output);
-            output += strideOnAxis;
-	    }
-	}
+    std::size_t oIndexStart = 0;  // offset of the current input's slice inside one outer block of the output
+    std::size_t oIndex = 0;  // running write position in the output
+    for (std::size_t inputId = 0; inputId < inputs.size(); ++inputId) {
+        oIndex = oIndexStart;
+        const DimSize_t iOffset = prodDimHigher*concatAxisValues[inputId];  // contiguous elements this input provides per outer index
+        for (std::size_t iIndex = 0; iIndex < prodDimLower; ++iIndex) {
+            std::copy(inputs[inputId] + iIndex*iOffset, inputs[inputId] + (iIndex+1)*iOffset, output + oIndex);
+            oIndex += prodDimHigher*outputAxisValue;  // advance by one outer block of the output
+        }
+        oIndexStart += concatAxisValues[inputId]*prodDimHigher;  // the next input starts right after this one's slice
+    }
 }
 
 namespace {
diff --git a/src/operator/ConcatImpl.cpp b/src/operator/ConcatImpl.cpp
index 16adfefc..d4605448 100644
--- a/src/operator/ConcatImpl.cpp
+++ b/src/operator/ConcatImpl.cpp
@@ -13,8 +13,6 @@
 #include <numeric> // std::accumulate
 #include <vector>
 
-#include "aidge/operator/Concat.hpp"
-
 #include "aidge/utils/Types.h"
 #include "aidge/data/Data.hpp"
 #include "aidge/data/Tensor.hpp"
@@ -88,4 +86,4 @@ void  Aidge::ConcatImpl_cpu::forward() {
                std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
 }
 
-void  Aidge::ConcatImpl_cpu::backward() { printf("Not implemented yet.\n"); }
+void  Aidge::ConcatImpl_cpu::backward() { printf("Not implemented yet.\n"); }  // stub: only prints a notice
diff --git a/unit_tests/operator/Test_ConcatImpl.cpp b/unit_tests/operator/Test_ConcatImpl.cpp
index e3a3c93e..7f616fcb 100644
--- a/unit_tests/operator/Test_ConcatImpl.cpp
+++ b/unit_tests/operator/Test_ConcatImpl.cpp
@@ -12,15 +12,13 @@
 #include <catch2/catch_test_macros.hpp>
 
 #include "aidge/data/Tensor.hpp"
-#include "aidge/operator/Concat.hpp"
+#include "aidge/operator/Add.hpp"
 
 #include "aidge/backend/cpu.hpp"
 
-#include <memory>
-
 using namespace Aidge;
 
-TEST_CASE("[cpu/operator] Concat(forward)") {
+TEST_CASE("[cpu/operator] Concat(forward)", "[Concat][CPU]") {
     SECTION("Concat 1D inputs") {
         std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array1D<int,2>{{ 2, 3 }});
         std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array1D<int,3>{{ 4, 5, 6 }});
@@ -46,87 +44,104 @@ TEST_CASE("[cpu/operator] Concat(forward)") {
 
         REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
     }
-        SECTION("2D Tensors") {
-        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array2D<float,2,2> {
+    SECTION("Concat 4D inputs on 1st axis") {
+        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
+            {                                       //
+                {                                   //
+                    {{20, 47},{21, 48},{22, 49}},   //
+                    {{23, 50},{24, 51},{25, 52}},   //
+                    {{26, 53},{27, 54},{28, 55}}    //
+                },                                  //
+            }                                       //
+        });                                         //
+        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array4D<int,2,3,3,2> {
             {
-                {0.00543531, 0.53726782},
-                {0.44371938, 0.93770550}
-            }
-        });
-        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array2D<float,2,2> {
-            {
-                {0.87131297, 0.22378820},
-                {0.74409730, 0.72109798}
-            }
-        });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,4> {
-            {
-                {0.00543531, 0.53726782, 0.87131297, 0.22378820},
-                {0.44371938, 0.93770550, 0.74409730, 0.72109798}
-            }
-        });
+                {                                   //
+                    {{29, 56},{30, 57},{31, 58}},   //
+                    {{32, 59},{33, 60},{34, 61}},   //
+                    {{35, 62},{36, 63},{37, 64}}    //
+                },                                  //
+                {                                   //
+                    {{38, 65},{39, 66},{40, 67}},   //
+                    {{41, 68},{42, 69},{43, 70}},   //
+                    {{44, 71},{45, 72},{46, 73}}    //
+                }                                   //
+            }                                       //
+        });                                         //
 
-        std::shared_ptr<Node> myConcat = Concat(2, 1);
-        auto op = std::static_pointer_cast<OperatorTensor>(myConcat->getOperator());
-        op->associateInput(0,input1);
-        op->associateInput(1,input2);
-        op->setDataType(DataType::Float32);
-        op->setBackend("cpu");
-        op->computeOutputDims();
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,3,3,3,2> {
+            {                                       //
+                {                                   //
+                    {{20, 47},{21, 48},{22, 49}},   //
+                    {{23, 50},{24, 51},{25, 52}},   //
+                    {{26, 53},{27, 54},{28, 55}}    //
+                },                                  //
+                {                                   //
+                    {{29, 56},{30, 57},{31, 58}},   //
+                    {{32, 59},{33, 60},{34, 61}},   //
+                    {{35, 62},{36, 63},{37, 64}}    //
+                },                                  //
+                {                                   //
+                    {{38, 65},{39, 66},{40, 67}},   //
+                    {{41, 68},{42, 69},{43, 70}},   //
+                    {{44, 71},{45, 72},{46, 73}}    //
+                }                                   //
+            }                                       //
+        });                                         //
+
+        auto myConcat = Concat(2, 0);
+        myConcat->getOperator()->associateInput(0, input1);
+        myConcat->getOperator()->associateInput(1, input2);
+        myConcat->getOperator()->setBackend("cpu");
+        myConcat->getOperator()->setDataType(DataType::Int32);
+        std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->computeOutputDims();
         myConcat->forward();
 
-        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+        std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0)->print();
+
+        REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
     }
-    SECTION("3D Tensors") {
-        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array3D<int,2,1,3> {
-            {
-                {
-                    {1, 2, 3}
-                },
-                {
-                    {4, 5, 6}
-                }
-            }
-        });
-        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array3D<int,2,2,3> {
-            {
-                {
-                    {10, 11, 12},
-                    {13, 14, 15}
-                },
-                {
-                    {16, 17, 18},
-                    {19, 20, 21}
-                }
-            }
-        });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,2,4,3> {
+
+    SECTION("Concat 4D inputs on 3rd axis") {
+        std::shared_ptr<Tensor> input1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
+            {                                       //
+                {                                   //
+                    {{20, 47},{21, 48},{22, 49}},   //
+                    {{23, 50},{24, 51},{25, 52}},   //
+                    {{26, 53},{27, 54},{28, 55}}    //
+                },                                  //
+            }                                       //
+        });                                         //
+        std::shared_ptr<Tensor> input2 = std::make_shared<Tensor>(Array4D<int,1,3,6,2> {
             {
-                {
-                    { 1, 2, 3 },
-                    { 10, 11, 12 },
-                    { 13, 14, 15 },
-                    { 1, 2, 3 }
+                {                                   //
+                    {{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}},   //
+                    {{32, 59},{33, 60},{34, 61},{41, 68},{42, 69},{43, 70}},   //
+                    {{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}}    //
                 },
-                {
-                    { 4, 5, 6 },
-                    { 16, 17, 18 },
-                    { 19, 20, 21 },
-                    { 4, 5, 6 }
-                }
             }
         });
 
-        std::shared_ptr<Node> myConcat = Concat(3, 1);
-        auto op = std::static_pointer_cast<OperatorTensor>(myConcat->getOperator());
-        op->associateInput(0,input1);
-        op->associateInput(1,input2);
-        op->associateInput(2,input1);
-        op->setDataType(DataType::Int32);
-        op->setBackend("cpu");
-        op->computeOutputDims();
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,9,2> {
+            {                                                                                             //
+                {                                                                                         //
+                    {{20, 47},{21, 48},{22, 49},{29, 56},{30, 57},{31, 58},{38, 65},{39, 66},{40, 67}},   //
+                    {{23, 50},{24, 51},{25, 52},{32, 59},{33, 60},{34, 61},{41, 68},{42, 69},{43, 70}},   //
+                    {{26, 53},{27, 54},{28, 55},{35, 62},{36, 63},{37, 64},{44, 71},{45, 72},{46, 73}}    //
+                },                                                                                        //
+            }                                                                                             //
+        });                                                                                               //
+
+        auto myConcat = Concat(2, 2);
+        myConcat->getOperator()->associateInput(0, input1);
+        myConcat->getOperator()->associateInput(1, input2);
+        myConcat->getOperator()->setBackend("cpu");
+        myConcat->getOperator()->setDataType(DataType::Int32);
+        std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->computeOutputDims();
         myConcat->forward();
 
-        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+        std::static_pointer_cast<Tensor>(myConcat->getOperator()->getRawOutput(0))->print();
+
+        REQUIRE(*std::static_pointer_cast<OperatorTensor>(myConcat->getOperator())->getOutput(0) == *expectedOutput);
     }
 }
\ No newline at end of file
-- 
GitLab