From 59fb7f6e4a4383c4c389dd4a63fb64a285a86465 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 6 Mar 2024 18:18:55 +0100 Subject: [PATCH] Fix issues with forwardDims() --- include/aidge/graph/GraphView.hpp | 1 - src/graph/GraphView.cpp | 74 +++++++++++-------------------- 2 files changed, 25 insertions(+), 50 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 4194ed4d3..46fa56ef0 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -523,7 +523,6 @@ private: // TOPOLOGY /////////////////////////////////////////////////////// - void _forwardDims(std::set<NodePtr> listNodes); }; /** diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 42eb410fd..005a7e679 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -328,8 +328,6 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType } void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) { - std::set<NodePtr> startNodes = inputNodes(); - // setInputs // Link every tensor to the right pointer // following parent - children informations @@ -340,7 +338,8 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor); } } - + + // Ensure every node in the graph is correctly connected for (std::shared_ptr<Node> nodePtr : getNodes()) { for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { // assess if the input was not already set and is a Tensor then link it to parent output @@ -362,60 +361,37 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ } } - - if (nodePtr->type() == Producer_Op::Type) { - startNodes.insert(nodePtr); - } } - // Compute dimensions of every node - _forwardDims(startNodes); -} - -void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { - // TODO: support multi-inputs/outputs - std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>(); - for (std::shared_ptr<Node> nodePtr : listNodes) { - if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { - const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); - if (!op->outputDimsForwarded()) { - op->computeOutputDims(); - } - if (!op->outputDimsForwarded()) { // try to compute output dimensions again later - nextList.insert(nodePtr); - } else { // compute output dimensions of children - std::set<std::shared_ptr<Node>> children = nodePtr->getChildren(); - for (auto child : children) { - const auto childOp = std::static_pointer_cast<OperatorTensor>(child->getOperator()); - if (!childOp->outputDimsForwarded()) { - nextList.insert(child); - } - } - } - } - } - if (nextList.empty()) { - for (std::shared_ptr<Node> nodePtr : getNodes()) { + // Compute dimensions of every node + std::set<std::shared_ptr<Node>> listNodes = getNodes(); + do { + std::set<std::shared_ptr<Node>> nextList; + for (std::shared_ptr<Node> nodePtr : listNodes) { if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { - if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) { - nextList.insert(nodePtr); - } + const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); + // Recompute everytime, even if it was already computed in a + // previous call of forwardDims(), as the graph may have changed! + op->computeOutputDims(); + if (!op->outputDimsForwarded()) { + nextList.insert(nodePtr); + } } } - } - // Internal check to make sure we won't enter in an infinite loop! - if (nextList == listNodes) { - std::vector<std::string> nodesName; - std::transform(nextList.begin(), nextList.end(), - std::back_inserter(nodesName), - [](auto val){ return val->name() + " (" + val->type() + ")"; }); - AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName); - } + // Internal check to make sure we won't enter in an infinite loop! + if (nextList == listNodes) { + // We are stuck! + std::vector<std::string> nodesName; + std::transform(nextList.begin(), nextList.end(), + std::back_inserter(nodesName), + [](auto val){ return val->name() + " (" + val->type() + ")"; }); + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName); + } - if (!nextList.empty()) { - _forwardDims(nextList); + listNodes.swap(nextList); } + while (!listNodes.empty()); } void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) { -- GitLab