From 59fb7f6e4a4383c4c389dd4a63fb64a285a86465 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 6 Mar 2024 18:18:55 +0100
Subject: [PATCH] Fix issues with forwardDims()

---
 include/aidge/graph/GraphView.hpp |  1 -
 src/graph/GraphView.cpp           | 74 +++++++++++--------------------
 2 files changed, 25 insertions(+), 50 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 4194ed4d3..46fa56ef0 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -523,7 +523,6 @@ private:
     //        TOPOLOGY
     ///////////////////////////////////////////////////////
 
-    void _forwardDims(std::set<NodePtr> listNodes);
 };
 
 /**
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 42eb410fd..005a7e679 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -328,8 +328,6 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
 }
 
 void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) {
-    std::set<NodePtr> startNodes = inputNodes();
-
     // setInputs
     // Link every tensor to the right pointer
     // following parent - children informations
@@ -340,7 +338,8 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
         mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor);
       }
     }
-      
+
+    // Ensure every node in the graph is correctly connected
     for (std::shared_ptr<Node> nodePtr : getNodes()) {
         for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
             // assess if the input was not already set and is a Tensor then link it to parent output
@@ -362,60 +361,37 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
             }
 
         }
-
-        if (nodePtr->type() == Producer_Op::Type) {
-          startNodes.insert(nodePtr);
-        }
     }
-    // Compute dimensions of every node
-    _forwardDims(startNodes);
 
-}
-
-void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
-    // TODO: support multi-inputs/outputs
-    std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
-    for (std::shared_ptr<Node> nodePtr : listNodes) {
-        if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
-            const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
-            if (!op->outputDimsForwarded()) {
-                op->computeOutputDims();
-            }
-            if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
-                nextList.insert(nodePtr);
-            } else { // compute output dimensions of children
-                std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
-                for (auto child : children) {
-                  const auto childOp = std::static_pointer_cast<OperatorTensor>(child->getOperator());
-                  if (!childOp->outputDimsForwarded()) {
-                    nextList.insert(child);
-                  }
-                }
-            }
-        }
-    }
-    if (nextList.empty()) {
-        for (std::shared_ptr<Node> nodePtr : getNodes()) {
+    // Compute dimensions of every node
+    std::set<std::shared_ptr<Node>> listNodes = getNodes();
+    do {
+        std::set<std::shared_ptr<Node>> nextList;
+        for (std::shared_ptr<Node> nodePtr : listNodes) {
             if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
-                if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) {
-                    nextList.insert(nodePtr);
-                }
+              const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
+              // Recompute everytime, even if it was already computed in a
+              // previous call of forwardDims(), as the graph may have changed!
+              op->computeOutputDims();
+              if (!op->outputDimsForwarded()) {
+                  nextList.insert(nodePtr);
+              }
             }
         }
-    }
 
-    // Internal check to make sure we won't enter in an infinite loop!
-    if (nextList == listNodes) {
-      std::vector<std::string> nodesName;
-      std::transform(nextList.begin(), nextList.end(),
-          std::back_inserter(nodesName),
-          [](auto val){ return val->name() + " (" + val->type() + ")"; });
-      AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName);
-    }
+        // Internal check to make sure we won't enter in an infinite loop!
+        if (nextList == listNodes) {
+            // We are stuck!
+            std::vector<std::string> nodesName;
+            std::transform(nextList.begin(), nextList.end(),
+                std::back_inserter(nodesName),
+                [](auto val){ return val->name() + " (" + val->type() + ")"; });
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName);
+        }
 
-    if (!nextList.empty()) {
-        _forwardDims(nextList);
+        listNodes.swap(nextList);
     }
+    while (!listNodes.empty());
 }
 
 void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) {
-- 
GitLab