From 37a9cafe87df52ca15cbe1afc3538032233df7f5 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 5 Jan 2024 16:13:00 +0000
Subject: [PATCH] Update Operators with device function

---
 include/aidge/operator/Erf.hpp        |  7 ++-----
 include/aidge/operator/Gather.hpp     | 10 +++-------
 include/aidge/operator/ReduceMean.hpp |  9 +++------
 include/aidge/operator/Reshape.hpp    |  7 ++-----
 include/aidge/operator/Transpose.hpp  |  7 ++-----
 5 files changed, 12 insertions(+), 28 deletions(-)

diff --git a/include/aidge/operator/Erf.hpp b/include/aidge/operator/Erf.hpp
index 6395756f3..6995cea5e 100644
--- a/include/aidge/operator/Erf.hpp
+++ b/include/aidge/operator/Erf.hpp
@@ -51,12 +51,9 @@ public:
         return std::make_shared<Erf_Op>(*this);
     }
 
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Erf_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp
index f82762228..20082eed2 100644
--- a/include/aidge/operator/Gather.hpp
+++ b/include/aidge/operator/Gather.hpp
@@ -40,7 +40,7 @@ public:
 
     Gather_Op() = delete;
 
-    
+
     using Attributes_ = StaticAttributes<GatherAttr, int>;
     template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
     Gather_Op(int axis)
@@ -70,13 +70,9 @@ public:
 
     void computeOutputDims() override final;
 
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Gather_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
-        getInput(1)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp
index 0acd21b28..52d011874 100644
--- a/include/aidge/operator/ReduceMean.hpp
+++ b/include/aidge/operator/ReduceMean.hpp
@@ -89,7 +89,7 @@ class ReduceMean_Op : public OperatorTensor,
                 }
                 else
                     outDims.push_back(getInput(0)->dims()[d]);
-            }        
+            }
             if(outDims.size()>0)
                 mOutputs[0]->resize(outDims);
             else
@@ -97,12 +97,9 @@ class ReduceMean_Op : public OperatorTensor,
         }
     }
 
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp
index 1ffa04596..32d71d5ad 100644
--- a/include/aidge/operator/Reshape.hpp
+++ b/include/aidge/operator/Reshape.hpp
@@ -66,12 +66,9 @@ public:
 
     void computeOutputDims() override final;
 
-    void setBackend(const std::string& name) override {
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Reshape_Op>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp
index f111be76c..2262bec14 100644
--- a/include/aidge/operator/Transpose.hpp
+++ b/include/aidge/operator/Transpose.hpp
@@ -79,12 +79,9 @@ class Transpose_Op : public OperatorTensor,
         }
     }
 
-    void setBackend(const std::string &name) override {
+    void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
         mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this);
-        mOutputs[0]->setBackend(name);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name);
+        mOutputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
-- 
GitLab