From 9a5e0331e8a093eeafb03c176fea635dfc4f3464 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Mon, 2 Oct 2023 18:49:09 +0200
Subject: [PATCH] Universal Python binding

---
 include/aidge/utils/DynamicAttributes.hpp | 120 +++++++++++++++++-----
 python_binding/utils/pybind_Parameter.cpp |  24 +----
 2 files changed, 97 insertions(+), 47 deletions(-)

diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp
index 1c3ac472e..37c4104b9 100644
--- a/include/aidge/utils/DynamicAttributes.hpp
+++ b/include/aidge/utils/DynamicAttributes.hpp
@@ -22,6 +22,13 @@
 #include "aidge/utils/Any.hpp"
 #include "aidge/utils/Attributes.hpp"
 
+#ifdef PYBIND
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+namespace py = pybind11;
+#endif
+
 
 namespace Aidge {
 
@@ -37,16 +44,40 @@ public:
      *  exist
      * \note at() throws if the Attribute does not exist, using find to test for Attribute existance
      */
-    template<class T> T& getAttr(const std::string& name)
+    template<class T> T getAttr(const std::string& name) const
     {
-        return libany::any_cast<T&>(mAttrs.at(name));
+#ifdef PYBIND
+        // If attribute does not exist in C++, it might have been created in Python
+        auto it = mAttrs.find(name);
+        if (it == mAttrs.end()) {
+            auto itPy = mAttrsPy.find(name);
+            if (itPy != mAttrsPy.end()) {
+                return itPy->second.cast<T>();
+            }
+        }
+#endif
+
+        return libany::any_cast<T>(mAttrs.at(name));
     }
 
+    // Note: return by reference is not possible because py::object::cast() returns a temporary
+/*
     template<class T> const T& getAttr(const std::string& name) const
     {
+#ifdef PYBIND
+        // If attribute does not exist in C++, it might have been created in Python
+        auto it = mAttrs.find(name);
+        if (it == mAttrs.end()) {
+            auto itPy = mAttrsPy.find(name);
+            if (itPy != mAttrsPy.end()) {
+                return itPy->second.cast<T>();
+            }
+        }
+#endif
+
         return libany::any_cast<const T&>(mAttrs.at(name));
     }
-
+*/
     ///\brief Add a new Attribute, identified by its name. If it already exists, asserts.
     ///\tparam T expected Attribute type
     ///\param name Attribute name
@@ -55,6 +86,11 @@ public:
     {
         const auto& res = mAttrs.emplace(std::make_pair(name, libany::any(std::forward<T>(value))));
         assert(res.second && "attribute already exists");
+
+#ifdef PYBIND
+        // Keep a copy of the attribute in py::object that is updated everytime
+        mAttrsPy.emplace(std::make_pair(name, py::cast(value)));
+#endif
     }
 
     ///\brief Set an Attribute value, identified by its name. If it already exists, its value (and type, if different) is changed.
@@ -66,59 +102,89 @@ public:
         auto res = mAttrs.emplace(std::make_pair(name, libany::any(std::forward<T>(value))));
         if (!res.second)
             res.first->second = std::move(libany::any(std::forward<T>(value)));
+
+#ifdef PYBIND
+        // Keep a copy of the attribute in py::object that is updated everytime
+        auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value)));
+        if (!resPy.second)
+            resPy.first->second = std::move(py::cast(value));
+#endif
+    }
+
+#ifdef PYBIND
+    void addAttrPy(const std::string& name, py::object&& value)
+    {
+        mAttrsPy.emplace(std::make_pair(name, value));
+    }
+
+    void setAttrPy(const std::string& name, py::object&& value)
+    {
+        auto resPy = mAttrsPy.emplace(std::make_pair(name, value));
+        if (!resPy.second)
+            resPy.first->second = std::move(value);
     }
+#endif
 
     //////////////////////////////////////
     ///     Generic Attributes API
     //////////////////////////////////////
     bool hasAttr(const std::string& name) const override final {
+#ifdef PYBIND
+        return (mAttrsPy.find(name) != mAttrsPy.end());
+#else
         return (mAttrs.find(name) != mAttrs.end());
+#endif
     }
 
     std::string getAttrType(const std::string& name) const override final {
+        // In order to remain consistent between C++ and Python, with or without PyBind, the name of the type is:
+        // - C-style for C++ created attributes
+        // - Python-style for Python created attributes
+#ifdef PYBIND
+        // If attribute does not exist in C++, it might have been created in Python
+        auto it = mAttrs.find(name);
+        if (it == mAttrs.end()) {
+            auto itPy = mAttrsPy.find(name);
+            if (itPy != mAttrsPy.end()) {
+                return std::string(Py_TYPE(itPy->second.ptr())->tp_name);
+            }
+        }
+#endif
+
         return mAttrs.at(name).type().name();
     }
 
     std::vector<std::string> getAttrsName() const override final {
         std::vector<std::string> attrsName;
+#ifdef PYBIND
+        for(auto const& it: mAttrsPy)
+            attrsName.push_back(it.first);
+#else
         for(auto const& it: mAttrs)
             attrsName.push_back(it.first);
+#endif
         return attrsName;
     }
 
-    #ifdef PYBIND
+#ifdef PYBIND
     /**
      * @detail See https://github.com/pybind/pybind11/issues/1590 as to why a
      * generic type caster for std::any is not feasable.
+     * The strategy here is to keep a copy of each attribute in py::object that is updated everytime.
     */
     py::object getAttrPy(const std::string& name) const {
-        py::object res = py::none();
-        const auto& attrType = mAttrs.at(name).type();
-        if(attrType == typeid(int))
-            res = py::cast(getAttr<int>(name));
-        else if(attrType == typeid(float))
-            res = py::cast(getAttr<float>(name));
-        else if(attrType == typeid(bool))
-            res = py::cast(getAttr<bool>(name));
-        else if(attrType == typeid(std::string))
-            res = py::cast(getAttr<std::string>(name));
-        else if(attrType == typeid(std::vector<bool>))
-            res = py::cast(getAttr<std::vector<bool>>(name));
-        else if(attrType == typeid(std::vector<int>))
-            res = py::cast(getAttr<std::vector<int>>(name));
-        else if(attrType == typeid(std::vector<float>))
-            res = py::cast(getAttr<std::vector<float>>(name));
-        else if(attrType == typeid(std::vector<std::string>))
-            res = py::cast(getAttr<std::vector<std::string>>(name));
-        else {
-            throw py::key_error("Failed to convert attribute type " + name + ", this issue may come from typeid function which gave an unknown key : [" + attrType.name() + "]. Please open an issue asking to add the support for this key.");
-        }
-        return res;
+        return mAttrsPy.at(name);
     };
-    #endif
+#endif
 
 private:
+    // Stores C++ attributes only
     std::map<std::string, libany::any> mAttrs;
+
+#ifdef PYBIND
+    // Stores C++ attributes (copy) and Python-only attributes
+    std::map<std::string, py::object> mAttrsPy;
+#endif
 };
 
 }
diff --git a/python_binding/utils/pybind_Parameter.cpp b/python_binding/utils/pybind_Parameter.cpp
index 6f6d9980a..5ed624320 100644
--- a/python_binding/utils/pybind_Parameter.cpp
+++ b/python_binding/utils/pybind_Parameter.cpp
@@ -6,30 +6,14 @@ namespace py = pybind11;
 namespace Aidge {
 void init_Attributes(py::module& m){
     py::class_<Attributes, std::shared_ptr<Attributes>>(m, "Attributes")
-    .def("has_attr", &Attributes::hasAttr)
-    .def("get_attr_type", &Attributes::getAttrType)
+    .def("has_attr", &Attributes::hasAttr, py::arg("name"))
+    .def("get_attr_type", &Attributes::getAttrType, py::arg("name"))
     .def("get_attrs_name", &Attributes::getAttrsName)
     .def("get_attr", &Attributes::getAttrPy, py::arg("name"));
 
     py::class_<DynamicAttributes, std::shared_ptr<DynamicAttributes>, Attributes>(m, "DynamicAttributes")
-    // add
-    .def("add_attr", &DynamicAttributes::addAttr<bool>)
-    .def("add_attr", &DynamicAttributes::addAttr<int>)
-    .def("add_attr", &DynamicAttributes::addAttr<float>)
-    .def("add_attr", &DynamicAttributes::addAttr<std::string>)
-    .def("add_attr", &DynamicAttributes::addAttr<std::vector<bool>>)
-    .def("add_attr", &DynamicAttributes::addAttr<std::vector<int>>)
-    .def("add_attr", &DynamicAttributes::addAttr<std::vector<float>>)
-    .def("add_attr", &DynamicAttributes::addAttr<std::vector<std::string>>)
-    // set
-    .def("set_attr", &DynamicAttributes::setAttr<bool>)
-    .def("set_attr", &DynamicAttributes::setAttr<int>)
-    .def("set_attr", &DynamicAttributes::setAttr<float>)
-    .def("set_attr", &DynamicAttributes::setAttr<std::string>)
-    .def("set_attr", &DynamicAttributes::setAttr<std::vector<bool>>)
-    .def("set_attr", &DynamicAttributes::setAttr<std::vector<int>>)
-    .def("set_attr", &DynamicAttributes::setAttr<std::vector<float>>)
-    .def("set_attr", &DynamicAttributes::setAttr<std::vector<std::string>>);
+    .def("add_attr", &DynamicAttributes::addAttrPy, py::arg("name"), py::arg("value"))
+    .def("set_attr", &DynamicAttributes::setAttrPy, py::arg("name"), py::arg("value"));
 }
 }
 
-- 
GitLab