From 5b2d9b4b4059d650f265d31648cdb0916a5562cf Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 31 May 2024 17:34:44 +0200
Subject: [PATCH] Use the ordered nodes API in Sequential() in order to
 preserve the user declaration nodes order

---
 src/graph/OpArgs.cpp | 48 +++++++++++++++++++++++++++++---------------
 1 file changed, 32 insertions(+), 16 deletions(-)

diff --git a/src/graph/OpArgs.cpp b/src/graph/OpArgs.cpp
index 124878fc4..e1a378c3d 100644
--- a/src/graph/OpArgs.cpp
+++ b/src/graph/OpArgs.cpp
@@ -18,23 +18,37 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs)
     std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
     for (const OpArgs& elt : inputs) {
         if(elt.node() != nullptr) {
-            // >= to allow incomplete graphViews
-            assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size());
-            /*
-            *  /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted.
-            *  Prefer a functional description for detailed inputs
-            */
-            for (const std::shared_ptr<Node>& node_ptr : gv->outputNodes()) {
-                node_ptr -> addChild(elt.node()); // already checks that node_ptr->nbOutput() == 1
+            // Connect the first output (ordered) of each output node (ordered) 
+            // to the next available input of the input node.
+            AIDGE_ASSERT(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size(),
+                "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
+                elt.node()->getNbFreeDataInputs(), elt.node()->name(), elt.node()->type(), gv->outputNodes().size());
+            std::set<NodePtr> connectedOutputs;
+            for (const auto& node_out : gv->getOrderedOutputs()) {
+                if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
+                    node_out.first -> addChild(elt.node(), node_out.second); // already checks that node_out->nbOutput() == 1
+                    connectedOutputs.insert(node_out.first);
+                }
             }
             gv->add(elt.node());
         }
         else {
-            for (std::shared_ptr<Node> node_in : elt.view()->inputNodes()) {
-                // >= to allow incomplete graphViews
-                assert(static_cast<std::size_t>(node_in->getNbFreeDataInputs()) >= gv->outputNodes().size());
-                for (std::shared_ptr<Node> node_out : gv->outputNodes()) {
-                    node_out -> addChild(node_in); // assert one output Tensor per output Node
+            // For each input node, connect the first output (ordered) of each 
+            // output node (ordered) to the next available input
+            std::set<NodePtr> connectedInputs;
+            for (const auto& node_in : elt.view()->getOrderedInputs()) {
+                if (connectedInputs.find(node_in.first) == connectedInputs.end()) {
+                    AIDGE_ASSERT(static_cast<std::size_t>(node_in.first->getNbFreeDataInputs()) >= gv->outputNodes().size(),
+                        "Sequential(): not enough free data inputs ({}) for input node {} (of type {}) to connect to all previous output nodes ({})",
+                        node_in.first->getNbFreeDataInputs(), node_in.first->name(), node_in.first->type(), gv->outputNodes().size());
+                    std::set<NodePtr> connectedOutputs;
+                    for (const auto& node_out : gv->getOrderedOutputs()) {
+                        if (connectedOutputs.find(node_out.first) == connectedOutputs.end()) {
+                            node_out.first -> addChild(node_in.first, node_out.second); // assert one output Tensor per output Node
+                            connectedOutputs.insert(node_out.first);
+                        }
+                    }
+                    connectedInputs.insert(node_in.first);
                 }
             }
             gv->add(elt.view());
@@ -58,16 +72,18 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
 
 std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
     std::shared_ptr<GraphView> gv = Sequential(inputs);
-    assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
+    AIDGE_ASSERT(gv->outputNodes().size() == 1U,
+        "Residual(): Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
     std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
-    assert(gv->inputNodes().size() == 2U && "Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
+    AIDGE_ASSERT(gv->inputNodes().size() == 2U,
+        "Residual(): Zero or more than one input Node for the GraphView, don't know which one to choose from for the residual connection");
     std::shared_ptr<Node> firstNode = nullptr;
     for (const std::shared_ptr<Node>& node_ptr : gv->inputNodes()) {
         if (node_ptr != lastNode) {
             firstNode = node_ptr;
         }
     }
-    assert(lastNode->getNbFreeDataInputs()>=1);
+    AIDGE_ASSERT(lastNode->getNbFreeDataInputs()>=1, "Residual(): missing a free data input for the output Node in order to connect the residual branch");
     gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
     return gv;
 }
-- 
GitLab