From 4ea1c64aabb3d981fbc2e2dc5a13672a16e824f6 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 4 Oct 2023 09:22:41 +0200 Subject: [PATCH] Fixed bug with setAttrPy --- aidge_core/unit_tests/test_operator_binding.py | 2 ++ include/aidge/utils/DynamicAttributes.hpp | 13 ++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index 437c10ad8..fc60f5227 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -89,6 +89,8 @@ class test_operator_binding(unittest.TestCase): # Check that added Python attribute is accessible in C++ # Return the value of an attribute named "d" of type float64 (double in C++) self.assertEqual(aidge_core.test_DynamicAttributes_binding_check(attrs), 18.56) + attrs.set_attr("d", 23.89) + self.assertEqual(aidge_core.test_DynamicAttributes_binding_check(attrs), 23.89) def test_compute_output_dims(self): in_dims=[25, 25] diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 24800428f..77645b61b 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -48,7 +48,7 @@ public: template<class T> T& getAttr(const std::string& name) { #ifdef PYBIND - // If attribute does not exist in C++, it might have been created in Python + // If attribute does not exist in C++, it might have been created or modified in Python auto it = mAttrs.find(name); if (it == mAttrs.end()) { auto itPy = mAttrsPy.find(name); @@ -66,7 +66,7 @@ public: template<class T> const T& getAttr(const std::string& name) const { #ifdef PYBIND - // If attribute does not exist in C++, it might have been created in Python + // If attribute does not exist in C++, it might have been created or modified in Python auto it = mAttrs.find(name); if (it == mAttrs.end()) { auto itPy = mAttrsPy.find(name); @@ -129,7 +129,11 @@ public: #ifdef PYBIND void addAttrPy(const std::string& name, py::object&& value) { - mAttrsPy.emplace(std::make_pair(name, value)); + auto it = mAttrs.find(name); + assert(it == mAttrs.end() && "attribute already exists"); + + const auto& res = mAttrsPy.emplace(std::make_pair(name, value)); + assert(res.second && "attribute already exists"); } void setAttrPy(const std::string& name, py::object&& value) @@ -137,6 +141,9 @@ public: auto resPy = mAttrsPy.emplace(std::make_pair(name, value)); if (!resPy.second) resPy.first->second = std::move(value); + + // Force getAttr() to take attribute value from mAttrsPy and update mAttrs + mAttrs.erase(name); } #endif -- GitLab