From a677e993271bc1f50fefbeb688060555fc8d1c4f Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 3 Dec 2024 14:54:02 +0100 Subject: [PATCH] Fixed issue --- include/aidge/utils/DynamicAttributes.hpp | 27 +++++++++++++++++++++++ include/aidge/utils/StaticAttributes.hpp | 7 +++++- src/operator/GenericOperator.cpp | 3 ++- unit_tests/recipes/Test_ToGenericOp.cpp | 13 ++++++----- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 16d5d94a1..dc169fb86 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -34,6 +34,12 @@ namespace py = pybind11; namespace Aidge { +// Detection idiom to check if a type T has a less-than operator +template <typename T, typename = void> +struct has_less_than_operator : std::false_type {}; + +template <typename T> +struct has_less_than_operator<T, std::void_t<decltype(std::declval<T>() < std::declval<T>())>> : std::true_type {}; ///\todo store also a fix-sized code that indicates the type ///\todo managing complex types or excluding non-trivial, non-aggregate types @@ -344,6 +350,14 @@ public: } }; + template<typename T> + static inline typename std::enable_if<!has_less_than_operator<T>::value, void>::type makeTypeConditionallyAvailable() {} + + template<typename T> + static inline typename std::enable_if<has_less_than_operator<T>::value, void>::type makeTypeConditionallyAvailable() { + mAnyUtils.emplace(typeid(T), std::unique_ptr<AnyUtils<T>>(new AnyUtils<T>())); + } + // Stores typed utils functions for each attribute type ever used static std::map<std::type_index, std::unique_ptr<AnyUtils_>> mAnyUtils; }; @@ -407,6 +421,19 @@ namespace std { return seed; } }; + + // Special case for std::array + template <typename T, std::size_t N> + struct hash<std::array<T, N>> { + std::size_t operator()(const std::array<T, N>& iterable) const { + std::size_t seed = 0; + for (const auto& v : iterable) { + // Recursively hash the value pointed by the iterator + Aidge::hash_combine(seed, std::hash<T>()(v)); + } + return seed; + } + }; } namespace future_std { diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp index 439d2c638..9c18d3cef 100644 --- a/include/aidge/utils/StaticAttributes.hpp +++ b/include/aidge/utils/StaticAttributes.hpp @@ -24,6 +24,7 @@ #endif #include "aidge/utils/Attributes.hpp" +#include "aidge/utils/DynamicAttributes.hpp" #include "aidge/utils/ErrorHandling.hpp" namespace Aidge { @@ -322,7 +323,11 @@ private: inline typename std::enable_if<I == sizeof...(Tp), void>::type appendAttr(const std::tuple<Tp...>& /*t*/, std::map<std::string, future_std::any>& /*attrs*/) const {} template<std::size_t I = 0, typename... Tp> - inline typename std::enable_if<I < sizeof...(Tp), void>::type appendAttr(const std::tuple<Tp...>& t, std::map<std::string, future_std::any>& attrs) const { + inline typename std::enable_if<I < sizeof...(Tp), void>::type appendAttr(const std::tuple<Tp...>& t, std::map<std::string, future_std::any>& attrs) const { + // Ensure that the type will be known to DynamicAttributes + using ElementType = typename std::tuple_element<I,std::tuple<Tp...>>::type; + DynamicAttributes::makeTypeConditionallyAvailable<ElementType>(); + attrs.insert(std::make_pair(EnumStrings<ATTRS_ENUM>::data[I], future_std::any(std::get<I>(t)))); appendAttr<I + 1, Tp...>(t, attrs); } diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp index b24c35352..1e28cf289 100644 --- a/src/operator/GenericOperator.cpp +++ b/src/operator/GenericOperator.cpp @@ -22,7 +22,8 @@ Aidge::GenericOperator_Op::GenericOperator_Op(const std::string& type, const std::vector<Aidge::InputCategory>& inputsCategory, Aidge::IOIndex_t nbOut) - : OperatorTensor(type, inputsCategory, nbOut) + : OperatorTensor(type, inputsCategory, nbOut), + mAttributes(std::make_shared<DynamicAttributes>()) { mImpl = std::make_shared<OperatorImpl>(*this); } diff --git a/unit_tests/recipes/Test_ToGenericOp.cpp b/unit_tests/recipes/Test_ToGenericOp.cpp index 53ad86e7c..886e07f95 100644 --- a/unit_tests/recipes/Test_ToGenericOp.cpp +++ b/unit_tests/recipes/Test_ToGenericOp.cpp @@ -50,8 +50,9 @@ TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") { // Ensure the conversion REQUIRE(newGenOp->type() == "Conv2D"); - REQUIRE(newGenOp->getOperator()->attributes() == convOp->getOperator()->attributes()); - + const auto convOpAttr = convOp->getOperator()->attributes()->getAttrs(); + const auto newGenOpAttr = (newGenOp->getOperator()->attributes()->getAttrs()); + REQUIRE((!(newGenOpAttr < convOpAttr) && !(convOpAttr < newGenOpAttr))); } SECTION("Test MetaOperator to Generic Operator") { @@ -60,14 +61,14 @@ TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") { REQUIRE(nbFused == 1); - std::shared_ptr<Node> MetaOpNode; + std::shared_ptr<Node> metaOpNode; for (const auto& nodePtr : g->getNodes()) { if (nodePtr->type() == "ConvReLUFC") { nodePtr->setName("ConvReLUFC_0"); - MetaOpNode = nodePtr; + metaOpNode = nodePtr; // Convert to GenericOperator toGenericOp(nodePtr); } @@ -78,7 +79,9 @@ TEST_CASE("[graph/convert] toGenericOp", "[toGenericOp][recipies]") { // Ensure the conversion REQUIRE(newGenOp->type() == "ConvReLUFC"); - REQUIRE(newGenOp->getOperator()->attributes() == MetaOpNode->getOperator()->attributes()); + const auto metaOpAttr = *std::static_pointer_cast<DynamicAttributes>(metaOpNode->getOperator()->attributes()); + const auto newGenOpAttr = *std::static_pointer_cast<DynamicAttributes>(newGenOp->getOperator()->attributes()); + REQUIRE((!(newGenOpAttr < metaOpAttr) && !(metaOpAttr < newGenOpAttr))); } -- GitLab