From 819554a15ec1154cdd1b45f96c0b03333a46c840 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 29 Nov 2024 16:31:17 +0000 Subject: [PATCH] Fix: attribute access for 'MetaOperator' --- include/aidge/operator/MetaOperator.hpp | 7 +++++-- src/operator/MetaOperator.cpp | 20 ++++---------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index b915cb8f1..47eb6cf97 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -36,7 +36,10 @@ public: std::shared_ptr<SequentialScheduler> mScheduler; std::weak_ptr<Node> mUpperNode; - public: +private: + const std::shared_ptr<DynamicAttributes> mAttributes = std::make_shared<DynamicAttributes>(); + +public: MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph, const std::vector<InputCategory>& forcedInputsCategory = {}); /** @@ -92,7 +95,7 @@ public: mGraph->setDataType(datatype); } - std::shared_ptr<Attributes> attributes() const override; + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override; Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 060c72548..cd307c9d1 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -48,6 +48,10 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar mOutputs[outputIdx] = std::dynamic_pointer_cast<Tensor>(outputOp.first->getOperator()->getRawOutput(outputOp.second)); } } + + for (const auto& node : mGraph->getRankedNodesName("{1}_{3}")) { + mAttributes->addAttr(node.second, node.first->getOperator()->attributes()); + } } std::shared_ptr<Aidge::Operator> Aidge::MetaOperator_Op::clone() const { @@ -119,22 +123,6 @@ std::set<std::string> Aidge::MetaOperator_Op::getAvailableBackends() const { return backendsList; } -std::shared_ptr<Aidge::Attributes> Aidge::MetaOperator_Op::attributes() const { - auto attrs = std::make_shared<DynamicAttributes>(); - - for (const auto& node : mGraph->getRankedNodesName("{3}")) { - const auto attributes = node.first->getOperator()->attributes(); - if (attributes) { - const auto nodeAttrs = DynamicAttributes(attributes->getAttrs()); - attrs->addAttr(node.first->type() + "#" + node.second, nodeAttrs); - if (node.second == "0") { - attrs->addAttr(node.first->type(), nodeAttrs); - } - } - } - - return attrs; -} Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { -- GitLab