From ff65e20b241f5bfe9addf904c0de4067220d4232 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 30 Apr 2024 11:48:12 +0200
Subject: [PATCH] Fixed multiple outputs support for GraphView::replace()

---
 include/aidge/graph/GraphView.hpp |  8 ++++++
 src/graph/GraphView.cpp           | 43 +++++++++++++++++--------------
 2 files changed, 32 insertions(+), 19 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 1a1714272..c9a4c11d7 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -486,6 +486,14 @@ public:
      */
     IOIndex_t getNbFreeDataInputs() const;
 
+    /**
+     * @brief Force update of GraphView inputs/outputs.
+     * It may be necessary to force the update of GraphView inputs/outputs when
+     * connections are added or removed inside the GraphView **after** the nodes
+     * were added.
+     */
+    void updateInputsOutputs();
+
 private:
 ///////////////////////////////////////////////////////
 //        TENSOR MANAGEMENT
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index b748bd4bc..273825172 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -910,7 +910,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
                                                      newGraph->getOrderedOutputs();
 
     auto inputParents = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOIn.size());
-    auto outputChildren = std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(oldOOut.size());
+    auto outputChildren = std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>>(oldOOut.size());
 
     // keep in memory every node related to the node to replace :
     // Parent
@@ -921,19 +921,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
         // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
     }
     // Children
-    for (std::size_t i = 0; i < oldOOut.size();) {
+    for (std::size_t i = 0; i < oldOOut.size(); ++i) {
         std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild =
               oldOOut[i].first -> output(oldOOut[i].second);
-        if (outputChild.empty()) {
-            outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
-            ++i;
-        }
-        else {
-            for (const auto& child : outputChild) {
-                if (oldNodes.find(child.first) == oldNodes.cend()) {
-                    outputChildren[i] = child;
-                    ++i;
-                }
+        for (const auto& child : outputChild) {
+            if (oldNodes.find(child.first) == oldNodes.cend()) {
+                outputChildren[i].push_back(child);
             }
         }
     }
@@ -971,8 +964,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
             }
         }
         for (std::size_t o = 0; o < oldOOut.size(); ++o) {
-            if (outputChildren[o].first) {
-                newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second);
+            for (const auto child : outputChildren[o]) {
+                newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second);
             }
         }
     }
@@ -982,15 +975,21 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
         if (newNodes.size() == 0) {
             // Case 3
             if (oldOIn.size() == oldOOut.size()) {
+                // Same number of inputs and outputs: connect each input to the corresponding output
                 for (std::size_t i = 0; i < oldOIn.size(); ++i) {
                     if (inputParents[i].first) {
-                      inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
+                      for (const auto child : outputChildren[i]) {
+                        inputParents[i].first -> addChild(child.first, inputParents[i].second, child.second);
+                      }
                     }
                 }
             }
             else if ((oldOIn.size() == 1) && (inputParents[0].first)) {
-                for (std::size_t i = 0; i < oldOIn.size(); ++i) {
-                    inputParents[0].first -> addChild(outputChildren[i].first, inputParents[0].second, outputChildren[i].second);
+                // Single input: connect the only input to all the outputs
+                for (std::size_t i = 0; i < oldOOut.size(); ++i) {
+                    for (const auto child : outputChildren[i]) {
+                        inputParents[0].first -> addChild(child.first, inputParents[0].second, child.second);
+                    }
                 }
             }
         }
@@ -1011,8 +1010,8 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
                 }
             }
             for (std::size_t o = 0; o < oldOOut.size(); ++o) {
-                if (outputChildren[o].first) {
-                    newOOut[o].first -> addChild(outputChildren[o].first, newOOut[o].second, outputChildren[o].second);
+                for (const auto child : outputChildren[o]) {
+                    newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second);
                 }
             }
         }
@@ -1061,6 +1060,12 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
     return true;
 }
 
+void Aidge::GraphView::updateInputsOutputs() {
+  for (auto node : mNodes) {
+    updateInputsOutputsNew(node);
+  }
+}
+
 void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
   // Can be called several times with the same node, e.g. when addChild() is
   // called on a node already part of the GraphView. In this case, inputs/outputs
-- 
GitLab