diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 163ea35c716cd6948c998f1b08f9f07d28fe1940..77ca0b00c40e578f45834a16da65ae37ac4b7d3c 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -31,6 +31,7 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Memorize.hpp" #include "aidge/utils/Directories.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" @@ -425,22 +426,68 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ } } - // Compute dimensions of every node - std::set<std::shared_ptr<Node>> listNodes = getNodes(); + // List of nodes that are already dims forwarded + std::set<std::shared_ptr<Node>> dimsForwarded; + // Establish initial list of dims forwardable nodes: + // input nodes and childs from Producers + std::set<std::shared_ptr<Node>> listNodes = inputNodes(); + for (const auto& nodePtr : getNodes()) { + if (nodePtr->type() == Producer_Op::Type) { + // Producers are already dims forwarded! + dimsForwarded.insert(nodePtr); + // Producers childs are dims forwardable + for (const auto& child : nodePtr->getChildren()) { + if (inView(child)) { + listNodes.insert(child); + } + } + } + } + do { std::set<std::shared_ptr<Node>> nextList; - for (std::shared_ptr<Node> nodePtr : listNodes) { + for (const auto& nodePtr : listNodes) { if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { - const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); - // Recompute everytime, even if it was already computed in a - // previous call of forwardDims(), as the graph may have changed! - op->forwardDims(allowDataDependency); - if (!op->dimsForwarded()) { - nextList.insert(nodePtr); - } + const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); + + bool anyParent = false; + bool parentsForwarded = true; + for (const auto& parent : nodePtr->getParents()) { + if (parent != nullptr && inView(parent) && dimsForwarded.find(parent) == dimsForwarded.end()) { + Log::debug("Dimensions not forwarded for parent (node {} (of type {})) of node {} (of type {})", + parent->name(), parent->type(), nodePtr->name(), nodePtr->type()); + parentsForwarded = false; + } + else { + anyParent = true; + } + } + + // Special rule for Memorize_Op, which only requires one parent + // to have its dims forwarded. This avoids circular dependency. + if (nodePtr->type() == Memorize_Op::Type && anyParent) { + parentsForwarded = true; + } + + if (parentsForwarded && op->forwardDims(allowDataDependency)) { + // Recompute everytime, even if it was already computed in a + // previous call of forwardDims(), as the graph may have changed! + dimsForwarded.insert(nodePtr); + for (const auto& child : nodePtr->getChildren()) { + if (inView(child) && dimsForwarded.find(child) == dimsForwarded.end()) { + nextList.insert(child); + } + } + } + else { + Log::debug("Unable to forward dimensions for node {} (of type {}) yet", nodePtr->name(), nodePtr->type()); + nextList.insert(nodePtr); + } } } + Log::debug("********************"); + // Internal check to make sure we won't enter in an infinite loop! if (nextList == listNodes) { // We are stuck! @@ -452,7 +499,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ Log::warn("Unable to forward dimensions (circular dependency and/or wrong dimensions and/or data dependent dimension?). Unable to compute output dims for nodes {}.", nodesName); return false; } - listNodes.swap(nextList); } while (!listNodes.empty());