From 519e386eb379b716b922d13d8be19d3b340e4519 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 29 Feb 2024 14:27:13 +0000 Subject: [PATCH] Make ProducerOp registrable in Python. --- include/aidge/operator/Producer.hpp | 6 ++---- include/aidge/utils/Registrar.hpp | 6 ++++-- python_binding/operator/pybind_Producer.cpp | 3 ++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 0731498dd..29fada1c5 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -28,7 +28,7 @@ enum class ProdAttr { Constant }; class Producer_Op : public OperatorTensor, - public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( + public Registrable<Producer_Op, std::string, std::shared_ptr<OperatorImpl>( const Producer_Op &)>, public StaticAttributes<ProdAttr, bool> { public: @@ -92,9 +92,7 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); } void setBackend(const std::string& name, DeviceIdx_t device = 0) override { - if (Registrar<Producer_Op>::exists({name})) { - mImpl = Registrar<Producer_Op>::create({name})(*this); - } + SET_IMPL_MACRO(Producer_Op, *this, name); mOutputs[0]->setBackend(name, device); } diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index b2a98113f..a5bd260ec 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -139,8 +139,10 @@ void declare_registrable(py::module& m, const std::string& class_name){ (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ } #else -#define SET_IMPL_MACRO(T_Op, op, backend_name) \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); +#define SET_IMPL_MACRO(T_Op, op, backend_name) \ + if (Registrar<T_Op>::exists(backend_name)) { \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + } #endif } diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 3caa438d1..025c8c5dd 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -26,6 +26,7 @@ void declare_Producer(py::module &m) { // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false); + } @@ -39,7 +40,7 @@ void init_Producer(py::module &m) { .def("get_outputs_name", &Producer_Op::getOutputsName) .def("attributes_name", &Producer_Op::staticGetAttrsName); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false); - + declare_registrable<Producer_Op>(m, "ProducerOp"); declare_Producer<1>(m); declare_Producer<2>(m); declare_Producer<3>(m); -- GitLab