From 9e4cdbb931176b56f25d512c8623568240c61b59 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 3 Oct 2023 17:21:01 +0200 Subject: [PATCH] Corrected attributes binding --- .../unit_tests/test_operator_binding.py | 6 +++ include/aidge/utils/Attributes.hpp | 4 +- include/aidge/utils/DynamicAttributes.hpp | 42 +++++++++++++------ include/aidge/utils/StaticAttributes.hpp | 6 +-- python_binding/utils/pybind_Parameter.cpp | 3 +- 5 files changed, 42 insertions(+), 19 deletions(-) diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index 96544ecda..344086871 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -31,7 +31,13 @@ class test_operator_binding(unittest.TestCase): def test_param_bool(self): self.generic_operator.add_attr("bool", True) + self.assertEqual(self.generic_operator.has_attr("bool"), True) self.assertEqual(self.generic_operator.get_attr("bool"), True) + self.assertEqual(self.generic_operator.get_attr_type("bool"), "bool") + self.assertEqual(self.generic_operator.get_attrs_name(), {"bool"}) + self.generic_operator.del_attr("bool") + self.assertEqual(self.generic_operator.has_attr("bool"), False) + self.assertEqual(len(self.generic_operator.get_attrs_name()), 0) def test_param_int(self): self.generic_operator.add_attr("int", 1) diff --git a/include/aidge/utils/Attributes.hpp b/include/aidge/utils/Attributes.hpp index efa923ed4..76875f15f 100644 --- a/include/aidge/utils/Attributes.hpp +++ b/include/aidge/utils/Attributes.hpp @@ -58,9 +58,9 @@ public: /** * @brief Get the attribute's name list. - * @return std::vector<std::string> Vector of names of the attributes. + * @return std::set<std::string> Vector of names of the attributes. */ - virtual std::vector<std::string> getAttrsName() const = 0; + virtual std::set<std::string> getAttrsName() const = 0; #ifdef PYBIND /* Bindable get function, does not recquire any templating. diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 41e2da940..0e91f0a53 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -25,6 +25,7 @@ #ifdef PYBIND #include <pybind11/pybind11.h> #include <pybind11/stl.h> +#include <pybind11/embed.h> namespace py = pybind11; #endif @@ -89,8 +90,11 @@ public: assert(res.second && "attribute already exists"); #ifdef PYBIND - // Keep a copy of the attribute in py::object that is updated everytime - mAttrsPy.emplace(std::make_pair(name, py::cast(value))); + // We cannot handle Python object if the Python interpreter is not running + if (Py_IsInitialized()) { + // Keep a copy of the attribute in py::object that is updated everytime + mAttrsPy.emplace(std::make_pair(name, py::cast(std::forward<T>(value)))); + } #endif } @@ -105,10 +109,20 @@ public: res.first->second = std::move(libany::any(std::forward<T>(value))); #ifdef PYBIND - // Keep a copy of the attribute in py::object that is updated everytime - auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(value))); - if (!resPy.second) - resPy.first->second = std::move(py::cast(value)); + // We cannot handle Python object if the Python interpreter is not running + if (Py_IsInitialized()) { + // Keep a copy of the attribute in py::object that is updated everytime + auto resPy = mAttrsPy.emplace(std::make_pair(name, py::cast(std::forward<T>(value)))); + if (!resPy.second) + resPy.first->second = std::move(py::cast(std::forward<T>(value))); + } +#endif + } + + void delAttr(const std::string& name) { + mAttrs.erase(name); +#ifdef PYBIND + mAttrsPy.erase(name); #endif } @@ -131,7 +145,8 @@ public: ////////////////////////////////////// bool hasAttr(const std::string& name) const override final { #ifdef PYBIND - return (mAttrsPy.find(name) != mAttrsPy.end()); + // Attributes might have been created in Python, the second condition is necessary. + return (mAttrs.find(name) != mAttrs.end() || mAttrsPy.find(name) != mAttrsPy.end()); #else return (mAttrs.find(name) != mAttrs.end()); #endif @@ -155,14 +170,14 @@ public: return mAttrs.at(name).type().name(); } - std::vector<std::string> getAttrsName() const override final { - std::vector<std::string> attrsName; + std::set<std::string> getAttrsName() const override final { + std::set<std::string> attrsName; + for(auto const& it: mAttrs) + attrsName.insert(it.first); #ifdef PYBIND + // Attributes might have been created in Python for(auto const& it: mAttrsPy) - attrsName.push_back(it.first); -#else - for(auto const& it: mAttrs) - attrsName.push_back(it.first); + attrsName.insert(it.first); #endif return attrsName; } @@ -185,6 +200,7 @@ private: // See https://pybind11.readthedocs.io/en/stable/faq.html: // “‘SomeClass’ declared with greater visibility than the type of its // field ‘SomeClass::member’ [-Wattributes]†+ // This map will only be populated if Python interpreter is running std::map<std::string, py::object> mAttrsPy; // Stores C++ attributes only // mutable because it may be updated in getAttr() from Python diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp index 5a00d4ad9..248fef3f3 100644 --- a/include/aidge/utils/StaticAttributes.hpp +++ b/include/aidge/utils/StaticAttributes.hpp @@ -162,10 +162,10 @@ public: assert(false && "attribute not found"); } - std::vector<std::string> getAttrsName() const override final { - std::vector<std::string> attrsName; + std::set<std::string> getAttrsName() const override final { + std::set<std::string> attrsName; for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { - attrsName.push_back(EnumStrings<ATTRS_ENUM>::data[i]); + attrsName.insert(EnumStrings<ATTRS_ENUM>::data[i]); } return attrsName; } diff --git a/python_binding/utils/pybind_Parameter.cpp b/python_binding/utils/pybind_Parameter.cpp index 5ed624320..a15a7fc2f 100644 --- a/python_binding/utils/pybind_Parameter.cpp +++ b/python_binding/utils/pybind_Parameter.cpp @@ -13,7 +13,8 @@ void init_Attributes(py::module& m){ py::class_<DynamicAttributes, std::shared_ptr<DynamicAttributes>, Attributes>(m, "DynamicAttributes") .def("add_attr", &DynamicAttributes::addAttrPy, py::arg("name"), py::arg("value")) - .def("set_attr", &DynamicAttributes::setAttrPy, py::arg("name"), py::arg("value")); + .def("set_attr", &DynamicAttributes::setAttrPy, py::arg("name"), py::arg("value")) + .def("del_attr", &DynamicAttributes::delAttr, py::arg("name")); } } -- GitLab