diff --git a/include/aidge/operator/GlobalAveragePooling.hpp b/include/aidge/operator/GlobalAveragePooling.hpp index 718782372c1eca2f73d2dd382f7525fbefb3e8a3..12c8eb02d9488edeb760b6a063cfac5f8257db18 100644 --- a/include/aidge/operator/GlobalAveragePooling.hpp +++ b/include/aidge/operator/GlobalAveragePooling.hpp @@ -17,8 +17,6 @@ #include <vector> #include "aidge/backend/OperatorImpl.hpp" -#include "aidge/data/Data.hpp" -#include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Registrar.hpp" @@ -43,9 +41,9 @@ public: GlobalAveragePooling_Op(const GlobalAveragePooling_Op &op) : OperatorTensor(op) { - if (op.mImpl){ - SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, op.mOutputs[0]->getImpl()->backend()); - }else{ + if (op.mImpl) { + SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, op.backend()); + } else { mImpl = nullptr; } } @@ -56,10 +54,7 @@ public: void computeOutputDims() override final; - void setBackend(const std::string &name, DeviceIdx_t device = 0) override { - SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, name); - mOutputs[0]->setBackend(name, device); - } + void setBackend(const std::string &name, DeviceIdx_t device = 0) override final; static const std::vector<std::string> getInputsName() { return {"data_input"}; diff --git a/src/operator/GlobalAveragePooling.cpp b/src/operator/GlobalAveragePooling.cpp index da760a4c89203c0415bb9a0259e25d5e7908b7d6..5781f014483164142187f07edb402cbde086dc43 100644 --- a/src/operator/GlobalAveragePooling.cpp +++ b/src/operator/GlobalAveragePooling.cpp @@ -40,3 +40,8 @@ void Aidge::GlobalAveragePooling_Op::computeOutputDims() { mOutputs[0]->resize(out_dims); } } + +void Aidge::GlobalAveragePooling_Op::setBackend(const std::string &name, DeviceIdx_t device) { + SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, name); + mOutputs[0]->setBackend(name, device); +} \ No newline at end of file