diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 7859ec91b13aa6a533be5170c528bc23646b4cdf..4e91182686f436718e97a739cf50da6f25c31608 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -79,11 +79,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Add_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - for (std::size_t i = 0; i < nbInputs(); ++i) { - getInput(i)->setBackend(name, device); - } } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 3a1303c84ce658ee63a77befebdd4a0ddbb98c5a..e3202f76bdfe6d9fa1591880faec300ca8eee614 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -137,9 +137,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index a14c4440f57badf410ee0b6b0ab5e985f7984425..e9e1c770447fa42bb93b395654f0688cd039e702 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -98,13 +98,23 @@ public: mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround + // By default, automatically set backend for scale, shift, mean and variance getInput(1)->setBackend(name, device); getInput(2)->setBackend(name, device); getInput(3)->setBackend(name, device); getInput(4)->setBackend(name, device); } + void setDataType(const DataType& dt) const override { + mOutputs[0]->setDataType(dt); + + // By default, automatically set data type for scale, shift, mean and variance + getInput(1)->setDataType(dt); + getInput(2)->setDataType(dt); + getInput(3)->setDataType(dt); + getInput(4)->setDataType(dt); + } + static const std::vector<std::string> getInputsName() { return {"data_input", "scale", "shift", "mean", "variance"}; } diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 465ff050877abac715924f0bda01ecad0d00308b..9d0eed63d27644b3e1cef0a5e1e8144f91aa1784 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -104,11 +104,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Concat_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - for (std::size_t i = 0; i < nbInputs(); ++i) { - getInput(i)->setBackend(name, device); - } } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index 62d174d4c901b44fa4a2255db947180684b89267..5bb0efd803f64f549b34fe775bdb2f42590f38f6 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -57,10 +57,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Div_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 93a9958e097068715d1d3e708ccbdffa6f51546e..a8bc0a477c823069318bc8f392140a99ffbc7bd7 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -70,9 +70,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<LeakyReLU_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index ed077b2145f9ef0691d8ade1b64f840eecb3e177..3d1ed900a171caf340e953063ed397fbbb175904 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -86,10 +86,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<MatMul_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index 76ee8cb83afa95a397cc7e5b92f858b401baa213..4494371803825e5380c54fef7741084a97c5a975 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -107,9 +107,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index 86365f3453ebfc3e584f18234f009aa0c83511d8..41049c43f110975cc2f2eda0a7f0d4c92c81c1ad 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -59,10 +59,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Mul_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp index ab54853a32586caafbdf70ce18d3966b1c4927b2..9f49cb9a9db918c232e5590d4f013c899082b80a 100644 --- a/include/aidge/operator/Pad.hpp +++ b/include/aidge/operator/Pad.hpp @@ -100,9 +100,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index a36f8ee9fa7f8a85bed7f784f3176d8c8754514e..464c49909e2cb15e2da88e8533ce10459860d875 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -57,10 +57,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Pow_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 41dd240314dbdf1f46949f7bae32616b4c3b6f6e..8a8f3f8544e8581518850bbfb8b2981f2a039fc5 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -54,9 +54,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<ReLU_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 0b1f710b174ed04dd2055f77a7d6f619ad403b76..6c49b784849b766615bb8baf4e8c7b0ba216d489 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -69,8 +69,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Scaling_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - // FIXME: temporary workaround - mInputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName() { diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 604f0e141c087b8d3184333a97b1a48d2cde66e7..e48de88c8726ba548a84d42abfe4feed2cb6d220 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -93,9 +93,6 @@ public: void setBackend(const std::string &name, int device = 0) override { mImpl = Registrar<Slice_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 48e4aa32392c48907bedd52d07ee3ef0a730573a..d4716dc5cb0908c0c64c332f50177280ebf0fa62 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -54,9 +54,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Softmax_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp index 9db8c37535389d24ff9e2146f719f46254647700..b679b3d6e2dd62d5290cc89f39a1025b18e8c37a 100644 --- a/include/aidge/operator/Sqrt.hpp +++ b/include/aidge/operator/Sqrt.hpp @@ -59,9 +59,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Sqrt_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index d55a29eeef33e5d41a16621938981d568fc54846..7eb7c8f9b1756d061b7abe30dc002785fc2c7573 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -62,10 +62,6 @@ public: void setBackend(const std::string& name, int device = 0) override { mImpl = Registrar<Sub_Op>::create(name)(*this); mOutputs[0]->setBackend(name, device); - - // FIXME: temporary workaround - getInput(0)->setBackend(name, device); - getInput(1)->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index ccee3865e5db112bd4d1b48e76db345e2313e6be..21a47962228949c1ae4256b4d9ef053fbf50ce76 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -149,14 +149,4 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { for (IOIndex_t i = 0; i < nbOutputs(); ++i) { getOutput(i)->setDataType(dataType); } - /* - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set"); - } - else { - getInput(i)->setDataType(dataType); - } - } - */ } \ No newline at end of file