From 519e386eb379b716b922d13d8be19d3b340e4519 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 29 Feb 2024 14:27:13 +0000
Subject: [PATCH] Make ProducerOp registrable in Python.

---
 include/aidge/operator/Producer.hpp         | 6 ++----
 include/aidge/utils/Registrar.hpp           | 6 ++++--
 python_binding/operator/pybind_Producer.cpp | 3 ++-
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp
index 0731498dd..29fada1c5 100644
--- a/include/aidge/operator/Producer.hpp
+++ b/include/aidge/operator/Producer.hpp
@@ -28,7 +28,7 @@ enum class ProdAttr { Constant };
 
 class Producer_Op
     : public OperatorTensor,
-      public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>(
+      public Registrable<Producer_Op, std::string, std::shared_ptr<OperatorImpl>(
                                           const Producer_Op &)>,
       public StaticAttributes<ProdAttr, bool> {
 public:
@@ -92,9 +92,7 @@ public:
     inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); }
 
     void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        if (Registrar<Producer_Op>::exists({name})) {
-            mImpl = Registrar<Producer_Op>::create({name})(*this);
-        }
+        SET_IMPL_MACRO(Producer_Op, *this, name);
         mOutputs[0]->setBackend(name, device);
     }
 
diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp
index b2a98113f..a5bd260ec 100644
--- a/include/aidge/utils/Registrar.hpp
+++ b/include/aidge/utils/Registrar.hpp
@@ -139,8 +139,10 @@ void declare_registrable(py::module& m, const std::string& class_name){
             (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
         }
 #else
-#define SET_IMPL_MACRO(T_Op, op, backend_name) \
-    (op).setImpl(Registrar<T_Op>::create(backend_name)(op));
+#define SET_IMPL_MACRO(T_Op, op, backend_name)                          \
+    if (Registrar<T_Op>::exists(backend_name)) {                        \
+        (op).setImpl(Registrar<T_Op>::create(backend_name)(op));        \
+    }
 #endif
 
 }
diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp
index 3caa438d1..025c8c5dd 100644
--- a/python_binding/operator/pybind_Producer.cpp
+++ b/python_binding/operator/pybind_Producer.cpp
@@ -26,6 +26,7 @@ void declare_Producer(py::module &m) {
     // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name"));
     m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false);
 
+
 }
 
 
@@ -39,7 +40,7 @@ void init_Producer(py::module &m) {
     .def("get_outputs_name", &Producer_Op::getOutputsName)
     .def("attributes_name", &Producer_Op::staticGetAttrsName);
     m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false);
-
+    declare_registrable<Producer_Op>(m, "ProducerOp");
     declare_Producer<1>(m);
     declare_Producer<2>(m);
     declare_Producer<3>(m);
-- 
GitLab