From dc3bfacba1e84407843098efdf7d8321c25f756c Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 29 Feb 2024 14:27:13 +0000
Subject: [PATCH] Make ProducerOp registrable in Python.

---
 include/aidge/operator/Producer.hpp         | 4 ++--
 python_binding/operator/pybind_Producer.cpp | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp
index fe9b044e2..e31d8c7d2 100644
--- a/include/aidge/operator/Producer.hpp
+++ b/include/aidge/operator/Producer.hpp
@@ -28,7 +28,7 @@ enum class ProdAttr { Constant };
 
 class Producer_Op
     : public OperatorTensor,
-      public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>(
+      public Registrable<Producer_Op, std::string, std::shared_ptr<OperatorImpl>(
                                           const Producer_Op &)>,
       public StaticAttributes<ProdAttr, bool> {
 public:
@@ -88,7 +88,7 @@ public:
     inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); }
 
     void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        mImpl = Registrar<Producer_Op>::create(name)(*this);
+        SET_IMPL_MACRO(Producer_Op, *this, name);
         mOutputs[0]->setBackend(name, device);
     }
 
diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp
index 3caa438d1..025c8c5dd 100644
--- a/python_binding/operator/pybind_Producer.cpp
+++ b/python_binding/operator/pybind_Producer.cpp
@@ -26,6 +26,7 @@ void declare_Producer(py::module &m) {
     // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name"));
     m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false);
 
+
 }
 
 
@@ -39,7 +40,7 @@ void init_Producer(py::module &m) {
     .def("get_outputs_name", &Producer_Op::getOutputsName)
     .def("attributes_name", &Producer_Op::staticGetAttrsName);
     m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false);
-
+    declare_registrable<Producer_Op>(m, "ProducerOp");
     declare_Producer<1>(m);
     declare_Producer<2>(m);
     declare_Producer<3>(m);
-- 
GitLab