From 27de830d2c3d6fa5f3ead03ff7d0670ec9b3f024 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 27 Sep 2024 17:35:15 +0200 Subject: [PATCH] Added DynamicAttributes to Node --- include/aidge/graph/Node.hpp | 14 ++++++++++--- include/aidge/operator/Operator.hpp | 7 +++++++ include/aidge/utils/DynamicAttributes.hpp | 2 ++ src/backend/OperatorImpl.cpp | 11 +++++++++-- src/graph/Node.cpp | 24 +++++++++++++++++------ src/utils/DynamicAttributes.cpp | 18 +++++++++++++++++ 6 files changed, 65 insertions(+), 11 deletions(-) diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 32932fa6f..e014b041f 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -54,7 +54,7 @@ private: return sharedA < sharedB; // shared_ptr has a valid comparison operator } }; - std::string mName; /** Name of the Node. Should be unique. */ + std::shared_ptr<DynamicAttributes> mAttrs; std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */ const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator @@ -70,6 +70,14 @@ private: public: Node() = delete; + /** + * @brief Construct a new Node object associated with the input Operator. + * @param op Operator giving the Node its number of connections. + * @param attrs Attributes for the Node. + */ + Node(std::shared_ptr<Operator> op, std::shared_ptr<DynamicAttributes> attrs); + Node(std::shared_ptr<Operator> op, const DynamicAttributes& attrs); + /** * @brief Construct a new Node object associated with the input Operator. * @param op Operator giving the Node its number of connections. @@ -120,7 +128,7 @@ public: * @brief Name of the Node. * @return std::string */ - inline std::string name() const noexcept { return mName; } + inline std::string name() const noexcept { return (mAttrs->hasAttr("name")) ? mAttrs->getAttr<std::string>("name") : ""; } /** * @brief Set the Node name. @@ -164,7 +172,7 @@ public: * @brief Get the Operator object of the Node. * @return std::shared_ptr<Operator> */ - inline std::shared_ptr<Operator> getOperator() const { return mOperator; } + inline std::shared_ptr<Operator> getOperator() const { return (*mOperator)(mAttrs); } /////////////////////////////////////////////////////// // TENSOR MANAGEMENT diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 87aa4080e..1cfb0d92e 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -50,6 +50,7 @@ enum class InputCategory { class Operator : public std::enable_shared_from_this<Operator> { protected: std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator + std::shared_ptr<DynamicAttributes> mInheritedAttrs; std::map<std::string, std::shared_ptr<Hook>> mHooks; private: @@ -84,12 +85,18 @@ public: // Hooks are not copied. } + std::shared_ptr<Operator> operator()(std::shared_ptr<DynamicAttributes> attrs) { + mInheritedAttrs = attrs; + return shared_from_this(); + } + virtual ~Operator() noexcept; public: virtual std::shared_ptr<Operator> clone() const = 0; virtual std::shared_ptr<Attributes> attributes() const { return nullptr; }; + virtual std::shared_ptr<DynamicAttributes> inheritedAttributes() const { return mInheritedAttrs; }; /** * @brief Set the specified input with a shallow copy. * @param inputIdx Index of the input to set. diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 6c6f3b8d9..dc066664b 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -341,6 +341,8 @@ public: static std::map<std::type_index, std::unique_ptr<AnyUtils_>> mAnyUtils; }; +template<> void DynamicAttributes::setAttr<future_std::any>(const std::string& name, const future_std::any& value); + #ifdef PYBIND template <> struct DynamicAttributes::AnyUtils<py::object> : public DynamicAttributes::AnyUtils_ { diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 0fa2cfdad..e2215e704 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -81,6 +81,13 @@ Aidge::ImplSpec Aidge::OperatorImpl::getRequiredSpec() const { else { requiredSpec.attrs.setAttr("type", mOp.type()); } + + const auto& inhAttrs = mOp.inheritedAttributes(); + if (inhAttrs) { + if (inhAttrs->hasAttr("impl")) { + requiredSpec.attrs.setAttr("impl", inhAttrs->getAny("impl")); + } + } return requiredSpec; } @@ -120,9 +127,9 @@ Aidge::ImplSpec Aidge::OperatorImpl::getBestMatch(const ImplSpec& requiredSpecs) std::string qualifier; const auto qualifierPos = std::find_if(attrName.begin(), attrName.end(), [](char c) { return c == ':'; }); - if (qualifierPos != attrName.begin()) { + if (qualifierPos != attrName.end()) { name = attrName.substr(0, qualifierPos - attrName.begin()); - qualifier = attrName.substr(qualifierPos - attrName.begin()); + qualifier = attrName.substr(qualifierPos - attrName.begin() + 1); } const bool mandatory = (qualifier == "!"); diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index b2ceb903d..c19eab12a 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -19,8 +19,8 @@ #include "aidge/operator/Producer.hpp" #include "aidge/utils/Types.h" -Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) - : mName(name), +Aidge::Node::Node(std::shared_ptr<Operator> op, std::shared_ptr<DynamicAttributes> attrs) + : mAttrs(attrs), mOperator(op), mParents(std::vector<std::shared_ptr<Node>>(static_cast<std::size_t>(op->nbInputs()), nullptr)), @@ -38,6 +38,18 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) } } +Aidge::Node::Node(std::shared_ptr<Operator> op, const DynamicAttributes& attrs) + : Node(op, std::make_shared<DynamicAttributes>(attrs)) {} + +Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) + : Node(op, DynamicAttributes()) +{ + // ctor + if (!name.empty()) { + mAttrs->setAttr<std::string>("name", name); + } +} + /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION /////////////////////////////////////////////////////// @@ -70,7 +82,7 @@ Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) { void Aidge::Node::setName(const std::string& name) { for (auto graphView : views()) graphView->updateNodeName(shared_from_this(), name); - mName = name; + mAttrs->setAttr<std::string>("name", name); } std::string Aidge::Node::createUniqueName(std::string baseName) @@ -399,18 +411,18 @@ void Aidge::Node::resetConnections(bool includeLearnableParam) { /////////////////////////////////////////////////////// Aidge::NodePtr Aidge::Node::cloneSharedOperators() const { - return std::make_shared<Node>(mOperator, mName); + return std::make_shared<Node>(mOperator, mAttrs); } Aidge::NodePtr Aidge::Node::cloneSharedProducers() const { std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type) ? mOperator : mOperator->clone(); - return std::make_shared<Node>(op, mName); + return std::make_shared<Node>(op, mAttrs); } Aidge::NodePtr Aidge::Node::clone() const { - return std::make_shared<Node>(mOperator->clone(), mName); + return std::make_shared<Node>(mOperator->clone(), mAttrs); } std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta, std::set<Aidge::NodePtr> nodeSee) { diff --git a/src/utils/DynamicAttributes.cpp b/src/utils/DynamicAttributes.cpp index facf34377..3bbf40038 100644 --- a/src/utils/DynamicAttributes.cpp +++ b/src/utils/DynamicAttributes.cpp @@ -13,6 +13,24 @@ std::map<std::type_index, std::unique_ptr<Aidge::DynamicAttributes::AnyUtils_>> Aidge::DynamicAttributes::mAnyUtils; +template<> void Aidge::DynamicAttributes::setAttr<future_std::any>(const std::string& name, const future_std::any& value) +{ + const auto dot = name.find('.'); + if (dot == name.npos) { + AIDGE_ASSERT(mAnyUtils.find(value.type()) != mAnyUtils.end(), "DynamicAttributes::setAttr(): cannot set value to std::any of never seen type."); + + auto res = mAttrs.emplace(std::make_pair(name, value)); + if (!res.second) + res.first->second = value; + } + else { + const auto ns = name.substr(0, dot); + const auto nsName = name.substr(dot + 1); + auto res = mAttrs.emplace(std::make_pair(ns, future_std::any(DynamicAttributes()))); + future_std::any_cast<DynamicAttributes&>(res.first->second).setAttr<future_std::any>(nsName, value); + } +} + bool future_std::operator<(const future_std::any& lhs, const future_std::any& rhs) { if (lhs.type() == rhs.type()) { return Aidge::DynamicAttributes::mAnyUtils.at(lhs.type())->compare(lhs, rhs); -- GitLab