From f1d899733784d2de2d9f354d15aba571e9b5d989 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 8 Dec 2023 22:48:59 +0100
Subject: [PATCH] Removed temporary workaround

---
 include/aidge/operator/Add.hpp        |  5 -----
 include/aidge/operator/AvgPooling.hpp |  3 ---
 include/aidge/operator/BatchNorm.hpp  | 12 +++++++++++-
 include/aidge/operator/Concat.hpp     |  5 -----
 include/aidge/operator/Div.hpp        |  4 ----
 include/aidge/operator/LeakyReLU.hpp  |  3 ---
 include/aidge/operator/MatMul.hpp     |  4 ----
 include/aidge/operator/MaxPooling.hpp |  3 ---
 include/aidge/operator/Mul.hpp        |  4 ----
 include/aidge/operator/Pad.hpp        |  3 ---
 include/aidge/operator/Pow.hpp        |  4 ----
 include/aidge/operator/ReLU.hpp       |  3 ---
 include/aidge/operator/Scaling.hpp    |  2 --
 include/aidge/operator/Slice.hpp      |  3 ---
 include/aidge/operator/Softmax.hpp    |  3 ---
 include/aidge/operator/Sqrt.hpp       |  3 ---
 include/aidge/operator/Sub.hpp        |  4 ----
 src/operator/OperatorTensor.cpp       | 10 ----------
 18 files changed, 11 insertions(+), 67 deletions(-)

diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp
index 7859ec91b..4e9118268 100644
--- a/include/aidge/operator/Add.hpp
+++ b/include/aidge/operator/Add.hpp
@@ -79,11 +79,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Add_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        for (std::size_t i = 0; i < nbInputs(); ++i) {
-            getInput(i)->setBackend(name, device);
-        }
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp
index 3a1303c84..e3202f76b 100644
--- a/include/aidge/operator/AvgPooling.hpp
+++ b/include/aidge/operator/AvgPooling.hpp
@@ -137,9 +137,6 @@ public:
     void setBackend(const std::string &name, int device = 0) override {
         mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index a14c4440f..e9e1c7704 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -98,13 +98,23 @@ public:
         mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
 
-        // FIXME: temporary workaround
+        // By default, automatically set backend for scale, shift, mean and variance
         getInput(1)->setBackend(name, device);
         getInput(2)->setBackend(name, device);
         getInput(3)->setBackend(name, device);
         getInput(4)->setBackend(name, device);
     }
 
+    void setDataType(const DataType& dt) const override {
+        mOutputs[0]->setDataType(dt);
+
+        // By default, automatically set data type for scale, shift, mean and variance
+        getInput(1)->setDataType(dt);
+        getInput(2)->setDataType(dt);
+        getInput(3)->setDataType(dt);
+        getInput(4)->setDataType(dt);
+    }
+
     static const std::vector<std::string> getInputsName() {
         return {"data_input", "scale", "shift", "mean", "variance"};
     }
diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp
index 465ff0508..9d0eed63d 100644
--- a/include/aidge/operator/Concat.hpp
+++ b/include/aidge/operator/Concat.hpp
@@ -104,11 +104,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Concat_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        for (std::size_t i = 0; i < nbInputs(); ++i) {
-            getInput(i)->setBackend(name, device);
-        }
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp
index 62d174d4c..5bb0efd80 100644
--- a/include/aidge/operator/Div.hpp
+++ b/include/aidge/operator/Div.hpp
@@ -57,10 +57,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Div_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
-        getInput(1)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp
index 93a9958e0..a8bc0a477 100644
--- a/include/aidge/operator/LeakyReLU.hpp
+++ b/include/aidge/operator/LeakyReLU.hpp
@@ -70,9 +70,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<LeakyReLU_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp
index ed077b214..3d1ed900a 100644
--- a/include/aidge/operator/MatMul.hpp
+++ b/include/aidge/operator/MatMul.hpp
@@ -86,10 +86,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<MatMul_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
-        getInput(1)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp
index 76ee8cb83..449437180 100644
--- a/include/aidge/operator/MaxPooling.hpp
+++ b/include/aidge/operator/MaxPooling.hpp
@@ -107,9 +107,6 @@ public:
     void setBackend(const std::string &name, int device = 0) override {
         mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp
index 86365f345..41049c43f 100644
--- a/include/aidge/operator/Mul.hpp
+++ b/include/aidge/operator/Mul.hpp
@@ -59,10 +59,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Mul_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
-        getInput(1)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp
index ab54853a3..9f49cb9a9 100644
--- a/include/aidge/operator/Pad.hpp
+++ b/include/aidge/operator/Pad.hpp
@@ -100,9 +100,6 @@ public:
     void setBackend(const std::string &name, int device = 0) override {
         mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp
index a36f8ee9f..464c49909 100644
--- a/include/aidge/operator/Pow.hpp
+++ b/include/aidge/operator/Pow.hpp
@@ -57,10 +57,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Pow_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
-        getInput(1)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp
index 41dd24031..8a8f3f854 100644
--- a/include/aidge/operator/ReLU.hpp
+++ b/include/aidge/operator/ReLU.hpp
@@ -54,9 +54,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<ReLU_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp
index 0b1f710b1..6c49b7848 100644
--- a/include/aidge/operator/Scaling.hpp
+++ b/include/aidge/operator/Scaling.hpp
@@ -69,8 +69,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Scaling_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-        // FIXME: temporary workaround
-        mInputs[0]->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName() {
diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp
index 604f0e141..e48de88c8 100644
--- a/include/aidge/operator/Slice.hpp
+++ b/include/aidge/operator/Slice.hpp
@@ -93,9 +93,6 @@ public:
     void setBackend(const std::string &name, int device = 0) override {
         mImpl = Registrar<Slice_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp
index 48e4aa323..d4716dc5c 100644
--- a/include/aidge/operator/Softmax.hpp
+++ b/include/aidge/operator/Softmax.hpp
@@ -54,9 +54,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Softmax_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp
index 9db8c3753..b679b3d6e 100644
--- a/include/aidge/operator/Sqrt.hpp
+++ b/include/aidge/operator/Sqrt.hpp
@@ -59,9 +59,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Sqrt_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp
index d55a29eee..7eb7c8f9b 100644
--- a/include/aidge/operator/Sub.hpp
+++ b/include/aidge/operator/Sub.hpp
@@ -62,10 +62,6 @@ public:
     void setBackend(const std::string& name, int device = 0) override {
         mImpl = Registrar<Sub_Op>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(0)->setBackend(name, device);
-        getInput(1)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index ccee3865e..21a479622 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -149,14 +149,4 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
     for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
         getOutput(i)->setDataType(dataType);
     }
-    /*
-    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
-        if (!getInput(i)) {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
-        }
-        else {
-            getInput(i)->setDataType(dataType);
-        }
-    }
-    */
 }
\ No newline at end of file
-- 
GitLab