From aef244d0cc1039fc5aa20f9dc4fbbf1705c5f1e2 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 22 Feb 2024 14:49:36 +0000
Subject: [PATCH] Initial working python registrar.

---
 include/aidge/operator/Conv.hpp               |  4 +--
 include/aidge/operator/Operator.hpp           | 10 +++++-
 include/aidge/utils/Registrar.hpp             | 31 ++++++++++++++++++-
 .../backend/pybind_OperatorImpl.cpp           |  2 +-
 python_binding/operator/pybind_Conv.cpp       | 10 +++---
 python_binding/operator/pybind_Operator.cpp   |  6 ++--
 6 files changed, 51 insertions(+), 12 deletions(-)

diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index be5fb3e39..6209d0006 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -31,7 +31,7 @@ enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelD
 
 template <DimIdx_t DIM>
 class Conv_Op : public OperatorTensor,
-                public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
+                public Registrable<Conv_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
                 public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
                                        DimSize_t, std::array<DimSize_t, DIM>> {
 
@@ -245,4 +245,4 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = {
 };
 }
 
-#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
\ No newline at end of file
+#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index cebc2d540..d6ac25616 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -115,13 +115,21 @@ public:
     virtual void setDataType(const DataType& dataType) const = 0;
 
     /**
-     * @brief Set the a new OperatorImpl to the Operator
+     * @brief Set a new OperatorImpl to the Operator
      *
      */
     void setImpl(std::shared_ptr<OperatorImpl> impl){
         mImpl = impl;
     }
 
+    /**
+     * @brief Get the OperatorImpl of the Operator
+     *
+     */
+    std::shared_ptr<OperatorImpl> getImpl(){
+        return mImpl;
+    }
+
     /**
      * @brief Minimum amount of data from a specific input for one computation pass.
      * @param inputIdx Index of the input analysed.
diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp
index 66a07eb0c..4479c3202 100644
--- a/include/aidge/utils/Registrar.hpp
+++ b/include/aidge/utils/Registrar.hpp
@@ -14,6 +14,9 @@
 
 #ifdef PYBIND
 #include <pybind11/pybind11.h>
+#include <pybind11/stl.h> // declare_registrable key can recquire stl
+#include <pybind11/functional.h>// declare_registrable allow binding of lambda fn
+
 #endif
 
 #include <functional>
@@ -57,7 +60,9 @@ struct Registrar {
     Registrar(const registrar_key& key, registrar_type func) {
         //printf("REGISTRAR: %s\n", key.c_str());
         bool newInsert;
-        std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
+        // std::tie(std::ignore, newInsert) = C::registry().insert(std::make_pair(key, func));
+        C::registry().erase(key);
+        C::registry().insert(std::make_pair(key, func));
         //assert(newInsert && "registrar already exists");
     }
 
@@ -79,6 +84,30 @@ struct Registrar {
         return keys;
     }
 };
+
+#ifdef PYBIND
+/**
+ * @brief Function to define register function for a registrable class
+ * Defined here to have access to this function in every module who wants
+ * to create a new registrable class.
+ *
+ * @tparam C registrable class
+ * @param m pybind module
+ * @param class_name python name of the class
+ */
+template <class C>
+void declare_registrable(py::module& m, const std::string& class_name){
+    typedef typename C::registrar_key registrar_key;
+    typedef typename C::registrar_type registrar_type;
+    m.def(("register_"+ class_name).c_str(), [](registrar_key& key, registrar_type function){
+        Registrar<C>(key, function);
+    })
+    .def(("get_keys_"+ class_name).c_str(), [](){
+        return Registrar<C>::getKeys();
+    });
+}
+#endif
+
 }
 
 #endif //AIDGE_CORE_UTILS_REGISTRAR_H_
diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp
index 346100690..5e63d77b7 100644
--- a/python_binding/backend/pybind_OperatorImpl.cpp
+++ b/python_binding/backend/pybind_OperatorImpl.cpp
@@ -107,7 +107,7 @@ public:
 void init_OperatorImpl(py::module& m){
 
     py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
-    .def(py::init<const Operator&>())
+    .def(py::init<const Operator&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>())
     .def("forward", &OperatorImpl::forward)
     .def("backward", &OperatorImpl::backward)
     .def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp
index 455ea4024..56c537d60 100644
--- a/python_binding/operator/pybind_Conv.cpp
+++ b/python_binding/operator/pybind_Conv.cpp
@@ -19,13 +19,15 @@
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/Registrar.hpp" // declare_registrable
 
 namespace py = pybind11;
 namespace Aidge {
 
 template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
+  const std::string pyClassName("ConvOp" + std::to_string(DIM) + "D");
   py::class_<Conv_Op<DIM>, std::shared_ptr<Conv_Op<DIM>>, Attributes, OperatorTensor>(
-    m, ("ConvOp" + std::to_string(DIM) + "D").c_str(),
+    m, pyClassName.c_str(),
     py::multiple_inheritance())
   .def(py::init<DimSize_t,
                 DimSize_t,
@@ -41,6 +43,8 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) {
     .def("get_outputs_name", &Conv_Op<DIM>::getOutputsName)
     .def("attributes_name", &Conv_Op<DIM>::staticGetAttrsName)
     ;
+  declare_registrable<Conv_Op<DIM>>(m, pyClassName);
+
 
   m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels,
                                                          DimSize_t out_channels,
@@ -66,9 +70,5 @@ void init_Conv(py::module &m) {
   declare_ConvOp<1>(m);
   declare_ConvOp<2>(m);
   declare_ConvOp<3>(m);
-
-  // FIXME:
-  // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const
-  // (&)[1])>(&Conv));
 }
 } // namespace Aidge
diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp
index 79a85cb92..05d6cd089 100644
--- a/python_binding/operator/pybind_Operator.cpp
+++ b/python_binding/operator/pybind_Operator.cpp
@@ -1,3 +1,4 @@
+
 /********************************************************************************
  * Copyright (c) 2023 CEA-List
  *
@@ -32,10 +33,11 @@ void init_Operator(py::module& m){
     .def("set_datatype", &Operator::setDataType, py::arg("dataType"))
     .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0)
     .def("forward", &Operator::forward)
-    // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
+    // py::keep_alive forbide Python to garbage collect the implementation lambda as long as the Operator is not deleted !
     .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
+    .def("get_impl", &Operator::getImpl)
     .def("get_hook", &Operator::getHook)
     .def("add_hook", &Operator::addHook)
     ;
 }
-}
\ No newline at end of file
+}
-- 
GitLab