From ce73448ad863cb35259bfa71371bed9b43721018 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 23 Jan 2024 16:03:57 +0100
Subject: [PATCH] change gather input into attr

---
 .../aidge/backend/cpu/operator/GatherImpl.hpp |  4 ++--
 .../operator/GatherImpl_forward_kernels.hpp   | 19 +++++++++----------
 src/operator/GatherImpl.cpp                   |  7 +------
 unit_tests/operator/Test_GatherImpl.cpp       | 10 ++++++----
 4 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/GatherImpl.hpp b/include/aidge/backend/cpu/operator/GatherImpl.hpp
index d22e484e..1d235ff1 100644
--- a/include/aidge/backend/cpu/operator/GatherImpl.hpp
+++ b/include/aidge/backend/cpu/operator/GatherImpl.hpp
@@ -24,10 +24,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class GatherImplForward_cpu
-    : public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
+    : public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 class GatherImplBackward_cpu
-    : public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
+    : public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 
 class GatherImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp
index 31119e27..591985e8 100644
--- a/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp
@@ -22,12 +22,13 @@
 
 namespace Aidge {
 template <class I, class O>
-void GatherImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const std::vector<DimSize_t>& indicesDims, const void* input_, const void* indexes_, void* output_)
+void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const std::vector<DimSize_t>& inputDims, const void* input_, void* output_)
 {
     const I* input = static_cast<const I*>(input_);
-    const int* indexes = static_cast<const int*>(indexes_);
     O* output = static_cast<O*>(output_);
 
+    std::size_t axisIdx = std::get<2>(attrs)>=0 ? std::get<2>(attrs) : static_cast<std::size_t>(std::get<2>(attrs)) + inputDims.size();
+
     std::size_t postAxisElems = 1;
     for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
         postAxisElems *= inputDims[i];
@@ -37,17 +38,15 @@ void GatherImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSiz
         preAxisElems *= inputDims[i];
     }
 
+    std::vector<std::int64_t> indices = std::get<0>(attrs);
     for (std::size_t i=0; i<preAxisElems; ++i)
     {
-        for(std::size_t idxRow=0; idxRow<indicesDims[0]; ++idxRow)
+        for(std::size_t j=0; j<indices.size(); ++j)
         {
-            for(std::size_t idxCol=0; idxCol<indicesDims[1]; ++idxCol)
-            {
-                std::size_t idx = indexes[indicesDims[1] *  idxRow + idxCol];
-                const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
-                std::copy_n(startPtr, postAxisElems, output);
-                output += postAxisElems;
-            }
+            std::size_t idx = indices[j] >= 0 ? indices[j] : indices[j] + inputDims[axisIdx];
+            const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
+            std::copy_n(startPtr, postAxisElems, output);
+            output += postAxisElems;
         }
     }
 }
diff --git a/src/operator/GatherImpl.cpp b/src/operator/GatherImpl.cpp
index fd5e755b..ce98627d 100644
--- a/src/operator/GatherImpl.cpp
+++ b/src/operator/GatherImpl.cpp
@@ -27,19 +27,14 @@ Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOInd
 }
 
 void Aidge::GatherImpl_cpu::forward() {
-    Gather_Op::Attrs attr = dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes();
-    const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx)
 
     auto kernelFunc = Registrar<GatherImplForward_cpu>::create({
         std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
     // Call kernel
-    kernelFunc(axisIdx,
+    kernelFunc(dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes(),
         std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims(),
         std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
 }
diff --git a/unit_tests/operator/Test_GatherImpl.cpp b/unit_tests/operator/Test_GatherImpl.cpp
index e0903aa7..a8345917 100644
--- a/unit_tests/operator/Test_GatherImpl.cpp
+++ b/unit_tests/operator/Test_GatherImpl.cpp
@@ -44,14 +44,16 @@ TEST_CASE("[cpu/operator] Gather(forward)") {
             }
         });
 
-        std::shared_ptr<Node> myGather = Gather(0);
+        std::shared_ptr<Node> myGather = Gather({1, 2}, {1, 2}, 0);
         auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator());
         op->associateInput(0,input);
-        op->associateInput(1,indexes);
+        // op->associateInput(1,indexes);
         op->setDataType(DataType::Int32);
         op->setBackend("cpu");
         op->computeOutputDims();
         myGather->forward();
+        op->getOutput(0)->print();
+        expectedOutput->print();
 
         REQUIRE(*(op->getOutput(0)) == *expectedOutput);
 
@@ -83,10 +85,10 @@ TEST_CASE("[cpu/operator] Gather(forward)") {
             }
         });
 
-        std::shared_ptr<Node> myGather = Gather(1);
+        std::shared_ptr<Node> myGather = Gather({0, 2}, {1, 2}, 1);
         auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator());
         op->associateInput(0,input);
-        op->associateInput(1,indexes);
+        // op->associateInput(1,indexes);
         op->setDataType(DataType::Int32);
         op->setBackend("cpu");
         op->computeOutputDims();
-- 
GitLab