diff --git a/include/aidge/operator/GlobalAveragePooling.hpp b/include/aidge/operator/GlobalAveragePooling.hpp index ee26d89f84f7632ddc5f9eafb911152211e26178..85441b340211f5f99ca15f69339c3d22e5fb99c7 100644 --- a/include/aidge/operator/GlobalAveragePooling.hpp +++ b/include/aidge/operator/GlobalAveragePooling.hpp @@ -34,7 +34,7 @@ namespace Aidge { class GlobalAveragePooling_Op : public OperatorTensor, public Registrable<GlobalAveragePooling_Op, std::string, - std::unique_ptr<OperatorImpl>( + std::shared_ptr<OperatorImpl>( const GlobalAveragePooling_Op &)> { public: static const std::string Type; @@ -43,9 +43,11 @@ public: GlobalAveragePooling_Op(const GlobalAveragePooling_Op &op) : OperatorTensor(op) { - mImpl = op.mImpl ? Registrar<GlobalAveragePooling_Op>::create( - op.mOutputs[0]->getImpl()->backend())(*this) - : nullptr; + if (op.mImpl){ + SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, op.mOutputs[0]->getImpl()->backend()); + }else{ + mImpl = nullptr; + } } std::shared_ptr<Operator> clone() const override { @@ -55,7 +57,7 @@ public: void computeOutputDims() override final; void setBackend(const std::string &name, DeviceIdx_t device = 0) override { - mImpl = Registrar<GlobalAveragePooling_Op>::create(name)(*this); + mImpl = SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, name); mOutputs[0]->setBackend(name, device); } diff --git a/python_binding/operator/pybind_GlobalAveragePooling.cpp b/python_binding/operator/pybind_GlobalAveragePooling.cpp index fcbd31a9ae6486ad36f8efca32d748476a32eaae..dda47d794e94056aee49565a5c8b9856ffffca47 100644 --- a/python_binding/operator/pybind_GlobalAveragePooling.cpp +++ b/python_binding/operator/pybind_GlobalAveragePooling.cpp @@ -18,13 +18,14 @@ namespace py = pybind11; namespace Aidge { +const std::string pyClassName("GlobalAveragePoolingOp"); void init_GlobalAveragePooling(py::module &m) { py::class_<GlobalAveragePooling_Op, std::shared_ptr<GlobalAveragePooling_Op>, - OperatorTensor>(m, "GlobalAveragePooling", + OperatorTensor>(m, pyClassName.c_str, py::multiple_inheritance()) .def("get_inputs_name", &GlobalAveragePooling_Op::getInputsName) .def("get_outputs_name", &GlobalAveragePooling_Op::getOutputsName); - + declare_registrable<GlobalAveragePooling_Op>(m, pyClassName); m.def("globalaveragepooling", &GlobalAveragePooling, py::arg("name") = ""); } } // namespace Aidge