From af1ba81f4aaa7685622b8fb8b87c8d3c41956cf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20KUBLER?= <gregoire.kubler@proton.me> Date: Tue, 26 Mar 2024 17:49:45 +0100 Subject: [PATCH] feat : add support for py registrar to GlobalAveragePooling_Op --- include/aidge/operator/GlobalAveragePooling.hpp | 12 +++++++----- .../operator/pybind_GlobalAveragePooling.cpp | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/include/aidge/operator/GlobalAveragePooling.hpp b/include/aidge/operator/GlobalAveragePooling.hpp index ee26d89f8..85441b340 100644 --- a/include/aidge/operator/GlobalAveragePooling.hpp +++ b/include/aidge/operator/GlobalAveragePooling.hpp @@ -34,7 +34,7 @@ namespace Aidge { class GlobalAveragePooling_Op : public OperatorTensor, public Registrable<GlobalAveragePooling_Op, std::string, - std::unique_ptr<OperatorImpl>( + std::shared_ptr<OperatorImpl>( const GlobalAveragePooling_Op &)> { public: static const std::string Type; @@ -43,9 +43,11 @@ public: GlobalAveragePooling_Op(const GlobalAveragePooling_Op &op) : OperatorTensor(op) { - mImpl = op.mImpl ? Registrar<GlobalAveragePooling_Op>::create( - op.mOutputs[0]->getImpl()->backend())(*this) - : nullptr; + if (op.mImpl){ + SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, op.mOutputs[0]->getImpl()->backend()); + }else{ + mImpl = nullptr; + } } std::shared_ptr<Operator> clone() const override { @@ -55,7 +57,7 @@ public: void computeOutputDims() override final; void setBackend(const std::string &name, DeviceIdx_t device = 0) override { - mImpl = Registrar<GlobalAveragePooling_Op>::create(name)(*this); + mImpl = SET_IMPL_MACRO(GlobalAveragePooling_Op, *this, name); mOutputs[0]->setBackend(name, device); } diff --git a/python_binding/operator/pybind_GlobalAveragePooling.cpp b/python_binding/operator/pybind_GlobalAveragePooling.cpp index fcbd31a9a..dda47d794 100644 --- a/python_binding/operator/pybind_GlobalAveragePooling.cpp +++ b/python_binding/operator/pybind_GlobalAveragePooling.cpp @@ -18,13 +18,14 @@ namespace py = pybind11; namespace Aidge { +const std::string pyClassName("GlobalAveragePoolingOp"); void init_GlobalAveragePooling(py::module &m) { py::class_<GlobalAveragePooling_Op, std::shared_ptr<GlobalAveragePooling_Op>, - OperatorTensor>(m, "GlobalAveragePooling", + OperatorTensor>(m, pyClassName.c_str, py::multiple_inheritance()) .def("get_inputs_name", &GlobalAveragePooling_Op::getInputsName) .def("get_outputs_name", &GlobalAveragePooling_Op::getOutputsName); - + declare_registrable<GlobalAveragePooling_Op>(m, pyClassName); m.def("globalaveragepooling", &GlobalAveragePooling, py::arg("name") = ""); } } // namespace Aidge -- GitLab