Skip to content
Snippets Groups Projects
Commit 02668cdc authored by Maxence Naud's avatar Maxence Naud
Browse files

Get back to previous state for GenericOperator parameter serialization

parent d2a875f3
No related branches found
No related tags found
Loading
......@@ -46,12 +46,12 @@ class test_operator_binding(unittest.TestCase):
self.assertEqual(self.generic_operator.get_parameter("str"), "value")
def test_param_l_int(self):
self.generic_operator.add_parameter("l_int", [1,2,3])
self.assertEqual(self.generic_operator.get_parameter("l_int"), [1,2,3])
self.generic_operator.add_parameter("l_int", [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15])
self.assertEqual(self.generic_operator.get_parameter("l_int"), [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15])
def test_param_l_bool(self):
self.generic_operator.add_parameter("l_bool", [True, False])
self.assertEqual(self.generic_operator.get_parameter("l_bool"), [True, False])
self.generic_operator.add_parameter("l_bool", [True, False, False, True])
self.assertEqual(self.generic_operator.get_parameter("l_bool"), [True, False, False, True])
def test_param_l_float(self):
self.generic_operator.add_parameter("l_float", [2.0, 1.0])
......
......@@ -79,6 +79,7 @@ class GenericOperator_Op
mParams.Add<T>(key, value);
}
std::string getParameterType(std::string const &key) { return mParams.getParamType(key); }
std::vector<std::string> getParametersName() { return mParams.getParametersName(); }
......@@ -88,7 +89,7 @@ class GenericOperator_Op
printf("Info: using associateInput() on a GenericOperator.\n");
}
void computeOutputDims() override final {
void computeOutputDims() override final {
assert(false && "Cannot compute output dim of a GenericOperator");
}
......@@ -115,7 +116,7 @@ class GenericOperator_Op
printf("Info: using getInput() on a GenericOperator.\n");
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx < mNbOut) && "output index out of range for this instance of GenericOperator");
printf("Info: using getOutput() on a GenericOperator.\n");
return mOutputs[outputIdx];
......
......@@ -15,8 +15,6 @@
#include <assert.h>
#include <map>
#include <vector>
#include <numeric>
#include <cstddef>
namespace Aidge {
......@@ -24,6 +22,13 @@ namespace Aidge {
///\todo managing complex types or excluding non-trivial, non-aggregate types
class CParameter
{
private:
template <typename T>
struct is_vector : std::false_type {};
template <typename T, typename Alloc>
struct is_vector<std::vector<T, Alloc>> : std::true_type {};
public:
// not copyable, not movable
CParameter(CParameter const &) = delete;
......@@ -31,6 +36,7 @@ public:
CParameter &operator=(CParameter const &) = delete;
CParameter &operator=(CParameter &&) = delete;
CParameter() : m_Params({}){};
~CParameter() = default;
/**
* \brief Returning a parameter identified by its name
......@@ -42,13 +48,13 @@ public:
* param buffer that will get invalid after the CParam death.
* \note at() throws if the parameter does not exist, using find to test for parameter existance
*/
template<class T> T Get(std::string const &i_ParamName) const
template<class T> T Get(std::string const i_ParamName) const
{
assert(m_Params.find(i_ParamName) != m_Params.end());
assert(m_Types.find(i_ParamName) != m_Types.end());
assert(m_Params.at(i_ParamName) <= m_Size);
assert(m_Params.at(i_ParamName) <= m_OffSet);
assert(typeid(T).name() == m_Types.at(i_ParamName));
return *reinterpret_cast<T *>(m_Buffer + m_Params.at(i_ParamName));
return *reinterpret_cast<T *>(m_BeginBuffer + m_Params.at(i_ParamName));
}
///\brief Add a parameter value, identified by its name
......@@ -60,24 +66,17 @@ public:
/// internal buffer in a new location (previous value is still in memory at its previous location)
template<class T> void Add(std::string const &i_ParamName, T const &i_Value)
{
const std::size_t addedSize = sizeof(T) / sizeof(std::uint8_t);
std::uint8_t *tmp = m_Buffer;
std::uint8_t *m_NewBuffer = static_cast<std::uint8_t *>(std::malloc((m_Size + addedSize)*sizeof(std::uint8_t)));
for (std::size_t i = 0; i < m_Size; ++i) {
m_NewBuffer[i] = m_Buffer[i];
}
free(tmp);
for (std::size_t i = 0; i < addedSize; ++i) {
m_NewBuffer[m_Size+i] = *(reinterpret_cast<const std::uint8_t *>(&i_Value) + i);
}
m_Buffer = m_NewBuffer;
m_Params[i_ParamName] = m_Size; // Copy pointer offset
m_Size += addedSize; // Increment offset
m_Buffer.resize(m_Buffer.size() + (sizeof(T) / sizeof(uint8_t)));
m_BeginBuffer = m_Buffer.data(); // Update buffer ptr in case of memory reordering
*reinterpret_cast<T *>(m_BeginBuffer + m_OffSet)
= i_Value; // Black-magic used to add anytype into the vector
m_Params[i_ParamName] = m_OffSet; // Copy pointer offset
m_OffSet += sizeof(T); // Increment offset
m_Types[i_ParamName] = typeid(i_Value).name();
}
std::string getParamType(std::string const &i_ParamName){
return m_Types[i_ParamName];
}
......@@ -89,27 +88,26 @@ public:
return parametersName;
}
~CParameter() {
free(m_Buffer);
}
private:
/// @brief Number of elements in m_Buffer
std::size_t m_Size = 0;
std::map<std::string, std::size_t> m_Params; // { Param name : offset }
///\brief Map to check type error
/* Note : i tried this : `std::map<std::string, std::type_info const *> m_Types;`
/* Note : i tried this : `std::map<std::string, std::type_info const *> mTypes;`
but looks like the type_ingo object was destroyed.
I am not a hugde fan of storing a string and making string comparison.
Maybe we can use a custom enum type (or is there a standard solution ?)
Maybe we can use a custom enum type (or is there a standard solution ?)
*/
std::map<std::string, std::string> m_Types;
///\brief All parameters values concatenated in raw binary form.
std::uint8_t *m_Buffer = static_cast<std::uint8_t *>(std::malloc(0));
std::vector<uint8_t> m_Buffer = {};
///\brief Starting address of the buffer
uint8_t *m_BeginBuffer = m_Buffer.data();
///\brief Offset, in number of uint8_t, of the next parameter to write
std::size_t m_OffSet = 0;
};
}
......
......@@ -22,33 +22,44 @@ TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") {
GenericOperator_Op Testop("TestOp", 1, 1, 1);
int value = 5;
const char* key = "intParam";
Testop.addParameter<int>(key, value);
Testop.addParameter(key, value);
REQUIRE(Testop.getParameter<int>(key) == value);
}
SECTION("LONG") {
GenericOperator_Op Testop("TestOp", 1, 1, 1);
long value = 3;
const char* key = "longParam";
Testop.addParameter<long>(key, value);
Testop.addParameter(key, value);
REQUIRE(Testop.getParameter<long>(key) == value);
}
SECTION("FLOAT") {
GenericOperator_Op Testop("TestOp", 1, 1, 1);
float value = 2.0;
const char* key = "floatParam";
Testop.addParameter<float>(key, value);
Testop.addParameter(key, value);
REQUIRE(Testop.getParameter<float>(key) == value);
}
SECTION("VECTOR<BOOL>") {
GenericOperator_Op Testop("TestOp", 1, 1, 1);
std::vector<bool> value = {true, false, false, true, true};
const char* key = "vect";
Testop.addParameter(key, value);
REQUIRE(Testop.getParameter<std::vector<bool>>(key).size() == value.size());
for (std::size_t i=0; i < value.size(); ++i){
REQUIRE(Testop.getParameter<std::vector<bool>>(key)[i] == value[i]);
}
}
SECTION("VECTOR<INT>") {
GenericOperator_Op Testop("TestOp", 1, 1, 1);
std::vector<int> value = {1, 2};
std::vector<int> value = {1, 2, 3, 4, 5, 6, 7, 8, 9};
const char* key = "vect";
Testop.addParameter<std::vector<int>>(key, value);
Testop.addParameter(key, value);
REQUIRE(Testop.getParameter<std::vector<int>>(key).size() == value.size());
for (std::size_t i=0; i < value.size(); ++i){
REQUIRE(Testop.getParameter<std::vector<int>>(key)[i] == value[i]);
}
}
}
SECTION("MULTIPLE PARAMS") {
/*
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment