From e65e2059fc779312510b11022f12eaec0018d805 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 6 Feb 2024 10:46:23 +0100 Subject: [PATCH] forwardDims() improvement --- src/graph/GraphView.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 5c7a524af..66421886f 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -19,6 +19,7 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/utils/ErrorHandling.hpp" /////////////////////////////////////////////////////// @@ -266,6 +267,8 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType } void Aidge::GraphView::forwardDims() { + std::set<NodePtr> startNodes = inputNodes(); + // setInputs // Link every tensor to the right pointer // following parent - children informations @@ -288,9 +291,14 @@ void Aidge::GraphView::forwardDims() { } } + + if (nodePtr->type() == Producer_Op::Type) { + startNodes.insert(nodePtr); + } } // Compute dimensions of every node - _forwardDims(inputNodes()); + _forwardDims(startNodes); + } void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { @@ -325,6 +333,7 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { } } + // Internal check to make sure we won't enter in an infinite loop! AIDGE_INTERNAL_ASSERT(nextList != listNodes); if (!nextList.empty()) { -- GitLab