diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 5c7a524af6c7a9e59bc22f29885b2be9906b3591..66421886f136ce77399b9bda54bc15fea91b0bc3 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -19,6 +19,7 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" #include "aidge/utils/ErrorHandling.hpp" /////////////////////////////////////////////////////// @@ -266,6 +267,8 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType } void Aidge::GraphView::forwardDims() { + std::set<NodePtr> startNodes = inputNodes(); + // setInputs // Link every tensor to the right pointer // following parent - children informations @@ -288,9 +291,14 @@ void Aidge::GraphView::forwardDims() { } } + + if (nodePtr->type() == Producer_Op::Type) { + startNodes.insert(nodePtr); + } } // Compute dimensions of every node - _forwardDims(inputNodes()); + _forwardDims(startNodes); + } void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { @@ -325,6 +333,7 @@ void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { } } + // Internal check to make sure we won't enter in an infinite loop! AIDGE_INTERNAL_ASSERT(nextList != listNodes); if (!nextList.empty()) {