diff --git a/include/aidge/operator/Erf.hpp b/include/aidge/operator/Erf.hpp index 6395756f3b08c5838d390ab45d38fa9c03cb91cb..6995cea5e4af9a17cf3d24516d9840850e701669 100644 --- a/include/aidge/operator/Erf.hpp +++ b/include/aidge/operator/Erf.hpp @@ -51,12 +51,9 @@ public: return std::make_shared<Erf_Op>(*this); } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Erf_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index f8276222811f6cc02c062d85e7ae99d72edead7a..20082eed28825ade9d62fb5d4e081840d3bd4442 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -40,7 +40,7 @@ public: Gather_Op() = delete; - + using Attributes_ = StaticAttributes<GatherAttr, int>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; Gather_Op(int axis) @@ -70,13 +70,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Gather_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp index 0acd21b28fac54e7e6d30e8219ead0e04ef777f6..52d0118743373c23a4afe4a51d3f22adbe9e6848 100644 --- a/include/aidge/operator/ReduceMean.hpp +++ b/include/aidge/operator/ReduceMean.hpp @@ -89,7 +89,7 @@ class ReduceMean_Op : public OperatorTensor, } else outDims.push_back(getInput(0)->dims()[d]); - } + } if(outDims.size()>0) mOutputs[0]->resize(outDims); else @@ -97,12 +97,9 @@ class ReduceMean_Op : public OperatorTensor, } } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 1ffa045960037f35167ae2d6e8904c49e2c55560..32d71d5adc3cfd92c9840dcb5bc61bfb6399c6db 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -66,12 +66,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Reshape_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp index f111be76cd712265e92e2e4c3e0220f79e13b1f7..2262bec14bd2f00cda643ade0709f7f9d509fa22 100644 --- a/include/aidge/operator/Transpose.hpp +++ b/include/aidge/operator/Transpose.hpp @@ -79,12 +79,9 @@ class Transpose_Op : public OperatorTensor, } } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){