From 48ecdca1fac1ed8be6f23793dfd9069978732cd3 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 10 Apr 2024 14:50:04 +0200
Subject: [PATCH] Added Gather default implementation

---
 include/aidge/operator/Gather.hpp | 17 +++++++++++----
 src/operator/Gather.cpp           | 36 ++++++++++++++++++++++++++++++-
 2 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp
index 7101a2f19..feb2474b0 100644
--- a/include/aidge/operator/Gather.hpp
+++ b/include/aidge/operator/Gather.hpp
@@ -25,6 +25,12 @@
 #include "aidge/utils/Types.h"
 
 namespace Aidge {
+class Gather_OpImpl : public OperatorImpl {
+public:
+    Gather_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
+    void forward() override;
+};
+
 enum class GatherAttr { Indices, GatheredShape, Axis };
 
 class Gather_Op : public OperatorTensor,
@@ -46,7 +52,9 @@ public:
                 attr<GatherAttr::Indices>(indices),
                 attr<GatherAttr::GatheredShape>(gatheredShape),
                 attr<GatherAttr::Axis>(axis))
-    {}
+    {
+        mImpl = std::make_shared<Gather_OpImpl>(*this);
+    }
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -56,10 +64,11 @@ public:
         : OperatorTensor(op),
           Attributes_(op)
     {
-        if (op.mImpl){
+        if (!op.backend().empty()) {
             SET_IMPL_MACRO(Gather_Op, *this, op.backend());
-        } else {
-            mImpl = nullptr;
+        }
+        else {
+            mImpl = std::make_shared<Gather_OpImpl>(*this);
         }
     }
 
diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp
index 082df8473..3b53aa5a2 100644
--- a/src/operator/Gather.cpp
+++ b/src/operator/Gather.cpp
@@ -20,6 +20,35 @@
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
 
+void Aidge::Gather_OpImpl::forward() {
+    const Gather_Op& op = dynamic_cast<const Gather_Op&>(mOp);
+    const auto axis = op.template getAttr<std::int64_t>("Axis");
+
+    const std::size_t axisIdx = axis>=0 ?
+                                axis :
+                                static_cast<std::size_t>(axis) + op.getInput(0)->dims().size();
+
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < op.getInput(0)->dims().size(); ++i) {
+        postAxisElems *= op.getInput(0)->dims()[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= op.getInput(0)->dims()[i];
+    }
+
+    const auto indices = op.template getAttr<std::vector<std::int64_t>>("Indices");
+    std::size_t outputOffset = 0;
+    for (std::size_t i=0; i<preAxisElems; ++i)
+    {
+        for(std::size_t j=0; j<indices.size(); ++j)
+        {
+            const std::size_t idx = indices[j] >= 0 ? indices[j] : static_cast<std::size_t>(indices[j]) + op.getInput(0)->dims()[axisIdx];
+            op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(i * postAxisElems * op.getInput(0)->dims()[axisIdx] + idx * postAxisElems), postAxisElems, outputOffset);
+            outputOffset += postAxisElems;
+        }
+    }
+}
 
 const std::string Aidge::Gather_Op::Type = "Gather";
 
@@ -53,6 +82,11 @@ bool Aidge::Gather_Op::computeOutputDims(bool /*allowDataDependency*/) {
 }
 
 void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
-    SET_IMPL_MACRO(Gather_Op, *this, name);
+    if (Registrar<Gather_Op>::exists({name})) {
+        SET_IMPL_MACRO(Gather_Op, *this, name);
+    }
+    else {
+        mImpl = std::make_shared<Gather_OpImpl>(*this);
+    }
     mOutputs[0]->setBackend(name, device);
 }
-- 
GitLab