From 414f663267fb29ae30d7789d6038a2b06ee70f68 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 5 Jan 2024 15:27:37 +0100
Subject: [PATCH] Make OperatorTensor::setDataType() more generic

---
 include/aidge/operator/BatchNorm.hpp     | 10 ----------
 include/aidge/operator/Conv.hpp          |  8 --------
 include/aidge/operator/ConvDepthWise.hpp |  8 --------
 include/aidge/operator/FC.hpp            |  8 --------
 src/operator/OperatorTensor.cpp          |  4 ++++
 5 files changed, 4 insertions(+), 34 deletions(-)

diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp
index d3e0fceab..4a0f40c03 100644
--- a/include/aidge/operator/BatchNorm.hpp
+++ b/include/aidge/operator/BatchNorm.hpp
@@ -105,16 +105,6 @@ public:
         getInput(4)->setBackend(name, device);
     }
 
-    void setDataType(const DataType& dt) const override {
-        mOutputs[0]->setDataType(dt);
-
-        // By default, automatically set data type for scale, shift, mean and variance
-        getInput(1)->setDataType(dt);
-        getInput(2)->setDataType(dt);
-        getInput(3)->setDataType(dt);
-        getInput(4)->setDataType(dt);
-    }
-
     static const std::vector<std::string> getInputsName() {
         return {"data_input", "scale", "shift", "mean", "variance"};
     }
diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index 6585c2d30..8ddb63964 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -180,14 +180,6 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
         getInput(2)->setBackend(name, device);
     }
 
-    void setDataType(const DataType& dt) const override {
-        mOutputs[0]->setDataType(dt);
-
-        // By default, automatically set data type for weight and bias inputs
-        getInput(1)->setDataType(dt);
-        getInput(2)->setDataType(dt);
-    }
-
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp
index 839a0ec79..d7ef9ff11 100644
--- a/include/aidge/operator/ConvDepthWise.hpp
+++ b/include/aidge/operator/ConvDepthWise.hpp
@@ -174,14 +174,6 @@ public:
         getInput(2)->setBackend(name, device);
     }
 
-    void setDataType(const DataType& dt) const override {
-        mOutputs[0]->setDataType(dt);
-
-        // By default, automatically set data type for weight and bias inputs
-        getInput(1)->setDataType(dt);
-        getInput(2)->setDataType(dt);
-    }
-
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp
index 36ff7106c..a73734ad2 100644
--- a/include/aidge/operator/FC.hpp
+++ b/include/aidge/operator/FC.hpp
@@ -104,14 +104,6 @@ public:
         getInput(2)->setBackend(name, device);
     }
 
-    void setDataType(const DataType& dt) const override {
-        mOutputs[0]->setDataType(dt);
-
-        // By default, automatically set data type for weight and bias inputs
-        getInput(1)->setDataType(dt);
-        getInput(2)->setDataType(dt);
-    }
-
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
     }
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index 21a479622..5dae6009a 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -149,4 +149,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
     for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
         getOutput(i)->setDataType(dataType);
     }
+
+    for (IOIndex_t i = nbData(); i < nbInputs(); ++i) {
+        getInput(i)->setDataType(dataType);
+    }
 }
\ No newline at end of file
-- 
GitLab