From f1d899733784d2de2d9f354d15aba571e9b5d989 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 8 Dec 2023 22:48:59 +0100 Subject: [PATCH] Removed temporary workaround --- include/aidge/operator/Add.hpp | 5 ----- include/aidge/operator/AvgPooling.hpp | 3 --- include/aidge/operator/BatchNorm.hpp | 12 +++++++++++- include/aidge/operator/Concat.hpp | 5 ----- include/aidge/operator/Div.hpp | 4 ---- include/aidge/operator/LeakyReLU.hpp | 3 --- include/aidge/operator/MatMul.hpp | 4 ---- include/aidge/operator/MaxPooling.hpp | 3 --- include/aidge/operator/Mul.hpp | 4 ---- include/aidge/operator/Pad.hpp | 3 --- include/aidge/operator/Pow.hpp | 4 ---- include/aidge/operator/ReLU.hpp | 3 --- include/aidge/operator/Scaling.hpp | 2 -- include/aidge/operator/Slice.hpp | 3 --- include/aidge/operator/Softmax.hpp | 3 --- include/aidge/operator/Sqrt.hpp | 3 --- include/aidge/operator/Sub.hpp | 4 ---- src/operator/OperatorTensor.cpp | 10 ---------- 18 files changed, 11 insertions(+), 67 deletions(-) diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 7859ec91b..4e9118268 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -79,11 +79,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Add_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - for (std::size_t i = 0; i < nbInputs(); ++i) { - getInput(i)->setBackend(name, device); - } } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 3a1303c84..e3202f76b 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -137,9 +137,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index a14c4440f..e9e1c7704 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -98,13 +98,23 @@ public: mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround + // By default, automatically set backend for scale, shift, mean and variance getInput(1)->setBackend(name, device); getInput(2)->setBackend(name, device); getInput(3)->setBackend(name, device); getInput(4)->setBackend(name, device); } + void setDataType(const DataType& dt) const override { + mOutputs[0]->setDataType(dt); + + // By default, automatically set data type for scale, shift, mean and variance + getInput(1)->setDataType(dt); + getInput(2)->setDataType(dt); + getInput(3)->setDataType(dt); + getInput(4)->setDataType(dt); + } + static const std::vector<std::string> getInputsName() { return {"data_input", "scale", "shift", "mean", "variance"}; } diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 465ff0508..9d0eed63d 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -104,11 +104,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Concat_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - for (std::size_t i = 0; i < nbInputs(); ++i) { - getInput(i)->setBackend(name, device); - } } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index 62d174d4c..5bb0efd80 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -57,10 +57,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Div_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 93a9958e0..a8bc0a477 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -70,9 +70,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<LeakyReLU_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index ed077b214..3d1ed900a 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -86,10 +86,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<MatMul_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index 76ee8cb83..449437180 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -107,9 +107,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index 86365f345..41049c43f 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -59,10 +59,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Mul_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp index ab54853a3..9f49cb9a9 100644 --- a/include/aidge/operator/Pad.hpp +++ b/include/aidge/operator/Pad.hpp @@ -100,9 +100,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index a36f8ee9f..464c49909 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -57,10 +57,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Pow_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 41dd24031..8a8f3f854 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -54,9 +54,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<ReLU_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 0b1f710b1..6c49b7848 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -69,8 +69,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Scaling_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround - mInputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName() { diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 604f0e141..e48de88c8 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -93,9 +93,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<Slice_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 48e4aa323..d4716dc5c 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -54,9 +54,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Softmax_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp index 9db8c3753..b679b3d6e 100644 --- a/include/aidge/operator/Sqrt.hpp +++ b/include/aidge/operator/Sqrt.hpp @@ -59,9 +59,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Sqrt_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index d55a29eee..7eb7c8f9b 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -62,10 +62,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Sub_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index ccee3865e..21a479622 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -149,14 +149,4 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); } - /* - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set"); - } - else { - getInput(i)->setDataType(dataType); - } - } - */ } \ No newline at end of file -- GitLab