From 414f663267fb29ae30d7789d6038a2b06ee70f68 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 5 Jan 2024 15:27:37 +0100 Subject: [PATCH] Make OperatorTensor::setDataType() more generic --- include/aidge/operator/BatchNorm.hpp | 10 ---------- include/aidge/operator/Conv.hpp | 8 -------- include/aidge/operator/ConvDepthWise.hpp | 8 -------- include/aidge/operator/FC.hpp | 8 -------- src/operator/OperatorTensor.cpp | 4 ++++ 5 files changed, 4 insertions(+), 34 deletions(-) diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index d3e0fceab..4a0f40c03 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -105,16 +105,6 @@ public: getInput(4)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for scale, shift, mean and variance - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - getInput(3)->setDataType(dt); - getInput(4)->setDataType(dt); - } - static const std::vector<std::string> getInputsName() { return {"data_input", "scale", "shift", "mean", "variance"}; } diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 6585c2d30..8ddb63964 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -180,14 +180,6 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel getInput(2)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for weight and bias inputs - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - } - static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 839a0ec79..d7ef9ff11 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -174,14 +174,6 @@ public: getInput(2)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for weight and bias inputs - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - } - static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 36ff7106c..a73734ad2 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -104,14 +104,6 @@ public: getInput(2)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for weight and bias inputs - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - } - static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 21a479622..5dae6009a 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -149,4 +149,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); } + + for (IOIndex_t i = nbData(); i < nbInputs(); ++i) { + getInput(i)->setDataType(dataType); + } } \ No newline at end of file -- GitLab