From 37a9cafe87df52ca15cbe1afc3538032233df7f5 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 5 Jan 2024 16:13:00 +0000 Subject: [PATCH] Update Operators with device function --- include/aidge/operator/Erf.hpp | 7 ++----- include/aidge/operator/Gather.hpp | 10 +++------- include/aidge/operator/ReduceMean.hpp | 9 +++------ include/aidge/operator/Reshape.hpp | 7 ++----- include/aidge/operator/Transpose.hpp | 7 ++----- 5 files changed, 12 insertions(+), 28 deletions(-) diff --git a/include/aidge/operator/Erf.hpp b/include/aidge/operator/Erf.hpp index 6395756f3..6995cea5e 100644 --- a/include/aidge/operator/Erf.hpp +++ b/include/aidge/operator/Erf.hpp @@ -51,12 +51,9 @@ public: return std::make_shared<Erf_Op>(*this); } - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Erf_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index f82762228..20082eed2 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -40,7 +40,7 @@ public: Gather_Op() = delete; - + using Attributes_ = StaticAttributes<GatherAttr, int>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; Gather_Op(int axis) @@ -70,13 +70,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Gather_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); - getInput(1)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp index 0acd21b28..52d011874 100644 --- a/include/aidge/operator/ReduceMean.hpp +++ b/include/aidge/operator/ReduceMean.hpp @@ -89,7 +89,7 @@ class ReduceMean_Op : public OperatorTensor, } else outDims.push_back(getInput(0)->dims()[d]); - } + } if(outDims.size()>0) mOutputs[0]->resize(outDims); else @@ -97,12 +97,9 @@ class ReduceMean_Op : public OperatorTensor, } } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<ReduceMean_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 1ffa04596..32d71d5ad 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -66,12 +66,9 @@ public: void computeOutputDims() override final; - void setBackend(const std::string& name) override { + void setBackend(const std::string& name, DeviceIdx_t device = 0) override { mImpl = Registrar<Reshape_Op>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp index f111be76c..2262bec14 100644 --- a/include/aidge/operator/Transpose.hpp +++ b/include/aidge/operator/Transpose.hpp @@ -79,12 +79,9 @@ class Transpose_Op : public OperatorTensor, } } - void setBackend(const std::string &name) override { + void setBackend(const std::string &name, DeviceIdx_t device = 0) override { mImpl = Registrar<Transpose_Op<DIM>>::create(name)(*this); - mOutputs[0]->setBackend(name); - - // FIXME: temporary workaround - getInput(0)->setBackend(name); + mOutputs[0]->setBackend(name, device); } static const std::vector<std::string> getInputsName(){ -- GitLab