diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index d3e0fceabd40d0ebf3c4521bd3010e1d1538b89f..4a0f40c034c7738a33eb8a9569fac4aa2fff465d 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -105,16 +105,6 @@ public: getInput(4)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for scale, shift, mean and variance - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - getInput(3)->setDataType(dt); - getInput(4)->setDataType(dt); - } - static const std::vector<std::string> getInputsName() { return {"data_input", "scale", "shift", "mean", "variance"}; } diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 6585c2d306b1a36544a36681ef7bd4b2e9d1b9b8..8ddb6396407ce09b5c5f24360ac9d7298e6232e3 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -180,14 +180,6 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel getInput(2)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for weight and bias inputs - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - } - static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 839a0ec793139a343d2f279dba283df9a81e88b2..d7ef9ff119a24f24f93b11e1f10b1076141f9a74 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -174,14 +174,6 @@ public: getInput(2)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for weight and bias inputs - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - } - static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 36ff7106cd3287ea45c743aace16cbc79f63820b..a73734ad20e10fe2a3e1d0d12d40e584b4540fb4 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -104,14 +104,6 @@ public: getInput(2)->setBackend(name, device); } - void setDataType(const DataType& dt) const override { - mOutputs[0]->setDataType(dt); - - // By default, automatically set data type for weight and bias inputs - getInput(1)->setDataType(dt); - getInput(2)->setDataType(dt); - } - static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 21a47962228949c1ae4256b4d9ef053fbf50ce76..5dae6009ae09042876782965ca52402f198b2a29 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -149,4 +149,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); } + + for (IOIndex_t i = nbData(); i < nbInputs(); ++i) { + getInput(i)->setDataType(dataType); + } } \ No newline at end of file