From 02668cdcde85b834d65848a21b3a679127e4347a Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Mon, 21 Aug 2023 10:01:43 +0000
Subject: [PATCH] Get back to previous state for GenericOperator parameter
 serialization

---
 .../unit_tests/test_operator_binding.py       |  8 +--
 include/aidge/operator/GenericOperator.hpp    |  5 +-
 include/aidge/utils/CParameter.hpp            | 62 +++++++++----------
 unit_tests/operator/Test_GenericOperator.cpp  | 23 +++++--
 4 files changed, 54 insertions(+), 44 deletions(-)

diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py
index d095439d3..b326e0748 100644
--- a/aidge_core/unit_tests/test_operator_binding.py
+++ b/aidge_core/unit_tests/test_operator_binding.py
@@ -46,12 +46,12 @@ class test_operator_binding(unittest.TestCase):
         self.assertEqual(self.generic_operator.get_parameter("str"), "value")
 
     def test_param_l_int(self):
-        self.generic_operator.add_parameter("l_int", [1,2,3])
-        self.assertEqual(self.generic_operator.get_parameter("l_int"), [1,2,3])
+        self.generic_operator.add_parameter("l_int", [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15])
+        self.assertEqual(self.generic_operator.get_parameter("l_int"), [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15])
 
     def test_param_l_bool(self):
-        self.generic_operator.add_parameter("l_bool", [True, False])
-        self.assertEqual(self.generic_operator.get_parameter("l_bool"), [True, False])
+        self.generic_operator.add_parameter("l_bool", [True, False, False, True])
+        self.assertEqual(self.generic_operator.get_parameter("l_bool"), [True, False, False, True])
 
     def test_param_l_float(self):
         self.generic_operator.add_parameter("l_float", [2.0, 1.0])
diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp
index 254d62c6b..86b96bfaa 100644
--- a/include/aidge/operator/GenericOperator.hpp
+++ b/include/aidge/operator/GenericOperator.hpp
@@ -79,6 +79,7 @@ class GenericOperator_Op
         mParams.Add<T>(key, value);
     }
 
+
     std::string getParameterType(std::string const &key) { return mParams.getParamType(key); }
 
     std::vector<std::string> getParametersName() { return mParams.getParametersName(); }
@@ -88,7 +89,7 @@ class GenericOperator_Op
         printf("Info: using associateInput() on a GenericOperator.\n");
     }
 
-    void computeOutputDims() override final { 
+    void computeOutputDims() override final {
         assert(false && "Cannot compute output dim of a GenericOperator");
     }
 
@@ -115,7 +116,7 @@ class GenericOperator_Op
         printf("Info: using getInput() on a GenericOperator.\n");
         return mInputs[inputIdx];
     }
-    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { 
+    inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
         assert((outputIdx < mNbOut) && "output index out of range for this instance of GenericOperator");
         printf("Info: using getOutput() on a GenericOperator.\n");
         return mOutputs[outputIdx];
diff --git a/include/aidge/utils/CParameter.hpp b/include/aidge/utils/CParameter.hpp
index a8f18c03f..64943ff58 100644
--- a/include/aidge/utils/CParameter.hpp
+++ b/include/aidge/utils/CParameter.hpp
@@ -15,8 +15,6 @@
 #include <assert.h>
 #include <map>
 #include <vector>
-#include <numeric>
-#include <cstddef>
 
 namespace Aidge {
 
@@ -24,6 +22,13 @@ namespace Aidge {
 ///\todo managing complex types or excluding non-trivial, non-aggregate types
 class CParameter
 {
+private:
+    template <typename T>
+    struct is_vector : std::false_type {};
+
+    template <typename T, typename Alloc>
+    struct is_vector<std::vector<T, Alloc>> : std::true_type {};
+
 public:
     // not copyable, not movable
     CParameter(CParameter const &) = delete;
@@ -31,6 +36,7 @@ public:
     CParameter &operator=(CParameter const &) = delete;
     CParameter &operator=(CParameter &&) = delete;
     CParameter() : m_Params({}){};
+    ~CParameter() = default;
 
     /**
      * \brief Returning a parameter identified by its name
@@ -42,13 +48,13 @@ public:
      *  param buffer that will get invalid after the CParam death.
      * \note at() throws if the parameter does not exist, using find to test for parameter existance
      */
-    template<class T> T Get(std::string const &i_ParamName) const
+    template<class T> T Get(std::string const i_ParamName) const
     {
         assert(m_Params.find(i_ParamName) != m_Params.end());
         assert(m_Types.find(i_ParamName) != m_Types.end());
-        assert(m_Params.at(i_ParamName) <= m_Size);
+        assert(m_Params.at(i_ParamName) <= m_OffSet);
         assert(typeid(T).name() == m_Types.at(i_ParamName));
-        return *reinterpret_cast<T *>(m_Buffer + m_Params.at(i_ParamName));
+        return *reinterpret_cast<T *>(m_BeginBuffer + m_Params.at(i_ParamName));
     }
 
     ///\brief Add a parameter value, identified by its name
@@ -60,24 +66,17 @@ public:
     /// internal buffer in a new location (previous value is still in memory at its previous location)
     template<class T> void Add(std::string const &i_ParamName, T const &i_Value)
     {
-        const std::size_t addedSize = sizeof(T) / sizeof(std::uint8_t);
-        std::uint8_t *tmp = m_Buffer;
-        std::uint8_t *m_NewBuffer = static_cast<std::uint8_t *>(std::malloc((m_Size + addedSize)*sizeof(std::uint8_t)));
-
-        for (std::size_t i = 0; i < m_Size; ++i) {
-            m_NewBuffer[i] = m_Buffer[i];
-        }
-        free(tmp);
-        for (std::size_t i = 0; i < addedSize; ++i) {
-            m_NewBuffer[m_Size+i] = *(reinterpret_cast<const std::uint8_t *>(&i_Value) + i);
-        }
-        m_Buffer = m_NewBuffer;
-
-        m_Params[i_ParamName] = m_Size; // Copy pointer offset
-        m_Size += addedSize; // Increment offset
+        m_Buffer.resize(m_Buffer.size() + (sizeof(T) / sizeof(uint8_t)));
+        m_BeginBuffer = m_Buffer.data(); // Update buffer ptr in case of memory reordering
+        *reinterpret_cast<T *>(m_BeginBuffer + m_OffSet)
+            = i_Value; // Black-magic used to add anytype into the vector
+        m_Params[i_ParamName] = m_OffSet; // Copy pointer offset
+        m_OffSet += sizeof(T); // Increment offset
+
         m_Types[i_ParamName] = typeid(i_Value).name();
     }
-    
+
+
     std::string getParamType(std::string const &i_ParamName){
         return m_Types[i_ParamName];
     }
@@ -89,27 +88,26 @@ public:
         return parametersName;
     }
 
-
-    ~CParameter() {
-        free(m_Buffer);
-    }
-
 private:
-    /// @brief Number of elements in m_Buffer
-    std::size_t m_Size = 0;
-
     std::map<std::string, std::size_t> m_Params; // { Param name : offset }
 
     ///\brief Map to check type error
-    /* Note : i tried this : `std::map<std::string, std::type_info const *> m_Types;`
+    /* Note : i tried this : `std::map<std::string, std::type_info const *> mTypes;`
     but looks like the type_ingo object was destroyed.
     I am not a hugde fan of storing a string and making string comparison.
-    Maybe we can use a custom enum type (or is there a standard solution ?)  
+    Maybe we can use a custom enum type (or is there a standard solution ?)
     */
     std::map<std::string, std::string> m_Types;
 
     ///\brief All parameters values concatenated in raw binary form.
-    std::uint8_t *m_Buffer = static_cast<std::uint8_t *>(std::malloc(0));
+    std::vector<uint8_t> m_Buffer = {};
+
+    ///\brief Starting address of the buffer
+    uint8_t *m_BeginBuffer = m_Buffer.data();
+
+    ///\brief Offset, in number of uint8_t, of the next parameter to write
+    std::size_t m_OffSet = 0;
+
 };
 
 }
diff --git a/unit_tests/operator/Test_GenericOperator.cpp b/unit_tests/operator/Test_GenericOperator.cpp
index ff41ed468..886326214 100644
--- a/unit_tests/operator/Test_GenericOperator.cpp
+++ b/unit_tests/operator/Test_GenericOperator.cpp
@@ -22,33 +22,44 @@ TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") {
         GenericOperator_Op Testop("TestOp", 1, 1, 1);
         int value = 5;
         const char* key = "intParam";
-        Testop.addParameter<int>(key, value);
+        Testop.addParameter(key, value);
         REQUIRE(Testop.getParameter<int>(key) == value);
     }
     SECTION("LONG") {
         GenericOperator_Op Testop("TestOp", 1, 1, 1);
         long value = 3;
         const char* key = "longParam";
-        Testop.addParameter<long>(key, value);
+        Testop.addParameter(key, value);
         REQUIRE(Testop.getParameter<long>(key) == value);
     }
     SECTION("FLOAT") {
         GenericOperator_Op Testop("TestOp", 1, 1, 1);
         float value = 2.0;
         const char* key = "floatParam";
-        Testop.addParameter<float>(key, value);
+        Testop.addParameter(key, value);
         REQUIRE(Testop.getParameter<float>(key) == value);
+    }
+     SECTION("VECTOR<BOOL>") {
+        GenericOperator_Op Testop("TestOp", 1, 1, 1);
+        std::vector<bool> value = {true, false, false, true, true};
+        const char* key = "vect";
+        Testop.addParameter(key, value);
+
+        REQUIRE(Testop.getParameter<std::vector<bool>>(key).size() == value.size());
+        for (std::size_t i=0; i < value.size(); ++i){
+            REQUIRE(Testop.getParameter<std::vector<bool>>(key)[i] == value[i]);
+        }
     }
     SECTION("VECTOR<INT>") {
         GenericOperator_Op Testop("TestOp", 1, 1, 1);
-        std::vector<int> value = {1, 2};
+        std::vector<int> value = {1, 2, 3, 4, 5, 6, 7, 8, 9};
         const char* key = "vect";
-        Testop.addParameter<std::vector<int>>(key, value);
+        Testop.addParameter(key, value);
 
         REQUIRE(Testop.getParameter<std::vector<int>>(key).size() == value.size());
         for (std::size_t i=0; i < value.size(); ++i){
             REQUIRE(Testop.getParameter<std::vector<int>>(key)[i] == value[i]);
-        } 
+        }
     }
     SECTION("MULTIPLE PARAMS") {
         /*
-- 
GitLab