From fe867fc3febf10b497e98a773f108bdf11d0aaa7 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 29 Jan 2025 18:39:02 +0100
Subject: [PATCH] Added missing binding for Elts struct

---
 include/aidge/scheduler/ProdConso.hpp         |  4 +
 python_binding/data/pybind_Elts.cpp           | 85 +++++++++++++++++++
 python_binding/pybind_core.cpp                |  2 +
 python_binding/scheduler/pybind_ProdConso.cpp |  1 +
 4 files changed, 92 insertions(+)
 create mode 100644 python_binding/data/pybind_Elts.cpp

diff --git a/include/aidge/scheduler/ProdConso.hpp b/include/aidge/scheduler/ProdConso.hpp
index f30e00afa..bc42cb36c 100644
--- a/include/aidge/scheduler/ProdConso.hpp
+++ b/include/aidge/scheduler/ProdConso.hpp
@@ -34,6 +34,10 @@ public:
         return std::make_unique<ProdConso>(op, true);
     }
 
+    const Operator& getOperator() const noexcept {
+        return mOp;
+    }
+
     /**
      * @brief Minimum amount of data from a specific input required by the
      * implementation to be run.
diff --git a/python_binding/data/pybind_Elts.cpp b/python_binding/data/pybind_Elts.cpp
new file mode 100644
index 000000000..59a8211e2
--- /dev/null
+++ b/python_binding/data/pybind_Elts.cpp
@@ -0,0 +1,85 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <algorithm>  // std::transform
+#include <cctype>     // std::tolower
+#include <string>     // std::string
+#include <vector>
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <pybind11/operators.h>
+
+#include "aidge/data/Elts.hpp"
+
+namespace py = pybind11;
+namespace Aidge {
+
+template <class T>
+void bindEnum(py::module& m, const std::string& name) {
+    // Define enumeration names for python as lowercase type name
+    // This defined enum names compatible with basic numpy type
+    // name such as: float32, flot64, [u]int32, [u]int64, ...
+    auto python_enum_name = [](const T& type) {
+        auto str_lower = [](std::string& str) {
+            std::transform(str.begin(), str.end(), str.begin(),
+                           [](unsigned char c){
+                               return std::tolower(c);
+                           });
+        };
+        auto type_name = std::string(Aidge::format_as(type));
+        str_lower(type_name);
+        return type_name;
+    };
+
+    // Auto generate enumeration names from lowercase type strings
+    std::vector<std::string> enum_names;
+    for (auto type_str : EnumStrings<T>::data) {
+        auto type = static_cast<T>(enum_names.size());
+        auto enum_name = python_enum_name(type);
+        enum_names.push_back(enum_name);
+    }
+
+    // Define python side enumeration aidge_core.type
+    auto e_type = py::enum_<T>(m, name.c_str());
+
+    // Add enum value for each enum name
+    for (std::size_t idx = 0; idx < enum_names.size(); idx++) {
+        e_type.value(enum_names[idx].c_str(), static_cast<T>(idx));
+    }
+
+    // Define str() to return the bare enum name value, it allows
+    // to compare directly for instance str(tensor.type())
+    // with str(nparray.type)
+    e_type.def("__str__", [enum_names](const T& type) {
+        return enum_names[static_cast<int>(type)];
+    }, py::prepend());
+}
+
+void init_Elts(py::module& m) {
+    bindEnum<Elts_t::EltType>(m, "EltType");
+    m.def("format_as", (const char* (*)(Elts_t::EltType)) &format_as, py::arg("elt"));
+    
+    py::class_<Elts_t, std::shared_ptr<Elts_t>>(
+        m, "Elts_t", py::dynamic_attr())
+        .def_static("none_elts", &Elts_t::NoneElts)
+        .def_static("data_elts", &Elts_t::DataElts, py::arg("data"), py::arg("token") = 1)
+        .def_static("token_elts", &Elts_t::TokenElts, py::arg("token"))
+        .def_readwrite("data", &Elts_t::data)
+        .def_readwrite("token", &Elts_t::token)
+        .def_readwrite("type", &Elts_t::type)
+        .def(py::self + py::self)
+        .def(py::self += py::self)
+        .def(py::self < py::self)
+        .def(py::self > py::self);
+}
+
+} // namespace Aidge
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index 1f35373f3..cc6f0bf25 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -21,6 +21,7 @@ void init_Random(py::module&);
 void init_Data(py::module&);
 void init_DataFormat(py::module&);
 void init_DataType(py::module&);
+void init_Elts(py::module&);
 void init_Database(py::module&);
 void init_DataProvider(py::module&);
 void init_Interpolation(py::module&);
@@ -112,6 +113,7 @@ void init_Aidge(py::module& m) {
     init_Data(m);
     init_DataFormat(m);
     init_DataType(m);
+    init_Elts(m);
     init_Database(m);
     init_DataProvider(m);
     init_Interpolation(m);
diff --git a/python_binding/scheduler/pybind_ProdConso.cpp b/python_binding/scheduler/pybind_ProdConso.cpp
index abd6d5379..547e2258d 100644
--- a/python_binding/scheduler/pybind_ProdConso.cpp
+++ b/python_binding/scheduler/pybind_ProdConso.cpp
@@ -104,6 +104,7 @@ void init_ProdConso(py::module& m){
     .def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
     .def_static("default_model", &ProdConso::defaultModel)
     .def_static("in_place_model", &ProdConso::inPlaceModel)
+    .def("get_operator", &ProdConso::getOperator)
     .def("get_nb_required_data", &ProdConso::getNbRequiredData)
     .def("get_nb_required_protected", &ProdConso::getNbRequiredProtected)
     .def("get_required_memory", &ProdConso::getRequiredMemory)
-- 
GitLab