From 1f2d196d6037d56bd5748d9f51a4eca9c1398140 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 28 Nov 2023 23:10:23 +0100
Subject: [PATCH] Working version: node ordering is now well defined

---
 include/aidge/graph/GraphView.hpp   |  16 ++-
 src/graph/GraphView.cpp             | 152 ++++++++++++++++++----------
 unit_tests/graph/Test_GraphView.cpp |  34 +++++--
 3 files changed, 137 insertions(+), 65 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 5462935be..df8352bcf 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -35,6 +35,9 @@ private:
     /// @brief Name of the graphview
     std::string mName;
 
+    /// @brief GraphView root node
+    NodePtr mRootNode;
+
     /// @brief Set of nodes included in the GraphView
     std::set<NodePtr> mNodes;
 
@@ -99,6 +102,10 @@ public:
         return mNodes.find(nodePtr) != mNodes.end();
     }
 
+    NodePtr getRootNode() {
+        return mRootNode;
+    }
+
 ///////////////////////////////////////////////////////
 //        TENSOR MANAGEMENT
 ///////////////////////////////////////////////////////
@@ -263,8 +270,9 @@ public:
      * @brief Include a set of Nodes to the current GraphView object.
      * @param otherNodes
      * @param includeLearnableParam
+     * @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
      */
-    void add(std::set<NodePtr> otherNodes,
+    bool add(std::set<NodePtr> otherNodes,
              bool includeLearnableParam = true);
 
     /**
@@ -272,16 +280,18 @@ public:
      * The second element in the otherNodes pair is the start node.
      * @param otherNodes
      * @param includeLearnableParam
+     * @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
      */
-    void add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
+    bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
              bool includeLearnableParam = true);
 
     /**
      * @brief Include every Node inside another GraphView to the current
      * GraphView.
      * @param other_graph GraphView containing the Nodes to include.
+     * @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
      */
-    void add(std::shared_ptr<GraphView> otherGraph);
+    bool add(std::shared_ptr<GraphView> otherGraph);
 
     /**
      * @brief Include a Node in the current GraphView and link it to another
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 499cdcf2c..f3095702b 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -75,15 +75,22 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
                 : node_ptr->name();
         namePtrTable[node_ptr] =
             (currentType + "_" + std::to_string(typeCounter[currentType]));
-        std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
-                    givenName.c_str());
+
+        if (node_ptr == mRootNode) {
+          std::fprintf(fp, "%s(%s):::rootCls\n", namePtrTable[node_ptr].c_str(),
+                      givenName.c_str());
+        }
+        else {
+          std::fprintf(fp, "%s(%s)\n", namePtrTable[node_ptr].c_str(),
+                      givenName.c_str());
+        }
     }
     // Write every link
     for (const std::shared_ptr<Node> &node_ptr : mNodes) {
       IOIndex_t outputIdx = 0;
       for (auto childs : node_ptr->getOrderedChildren()) {
         for (auto child : childs) {
-          if (child) {
+          if (child != nullptr && mNodes.find(child) != mNodes.end()) {
             IOIndex_t inputIdx = 0;
             for (auto pa_ptr : child->getParents()) {
               if (pa_ptr == node_ptr) {
@@ -116,6 +123,7 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
 
     std::fprintf(fp, "classDef inputCls fill:#afa\n");
     std::fprintf(fp, "classDef outputCls fill:#ffa\n");
+    std::fprintf(fp, "classDef rootCls stroke:#f00\n");
 
     if (verbose) {
       for (const auto &c : typeCounter) {
@@ -382,6 +390,11 @@ void Aidge::GraphView::setInputId(Aidge::IOIndex_t /*inID*/,
 }
 
 void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnableParam) {
+  // first node to be added to the graph is the root node by default
+  if (mRootNode == nullptr) {
+    mRootNode = node;
+  }
+
   // add to the GraphView nodes
   node->addView(shared_from_this());
   mNodes.insert(node);
@@ -407,80 +420,117 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara
   }
 }
 
-void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
+bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool includeLearnableParam) {
+  if (otherNodes.empty()) {
+    return true;
+  }
+
+  bool orderUnicity = true;
+
   // List only the nodes that are not already present in current graph
   std::set<NodePtr> nodesToAdd;
   std::set_difference(otherNodes.begin(), otherNodes.end(), mNodes.begin(), mNodes.end(), std::inserter(nodesToAdd, nodesToAdd.begin()));
 
-  do {
-    std::set<NodePtr> nextNodesToAdd;
-
-    // Find nodes that are direct parent of current GraphView and add them first
-    // such that the obtained GraphView inputs list will be the same, regardless 
-    // of the evaluation order of those nodes
-    // (i.e. one of their child is in current GraphView)
-    for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
-      for (auto child : (*it)->getChildren()) {
-        if (mNodes.find(child) != mNodes.end()) {
-          nextNodesToAdd.insert(*it);
-          it = nodesToAdd.erase(it);
+  // List the nodes to rank, initially all the nodes in the GraphView
+  std::set<NodePtr> nodesToRank(mNodes);
+  nodesToRank.insert(nodesToAdd.begin(), nodesToAdd.end());
+  std::vector<NodePtr> rankedNodesToAdd;
+
+  if (mRootNode == nullptr) {
+    std::set<NodePtr> noParentNodes;
+
+    // If no root node is defined, check nodes without parents
+    for (auto node : nodesToRank) {
+      bool noParent = true;
+      for (auto parent : node->getParents()) {
+        if (parent != nullptr && nodesToRank.find(parent) != nodesToRank.end()) {
+          noParent = false;
           break;
         }
       }
-      if (it == nodesToAdd.end()) {
-        break;
+
+      if (noParent) {
+        noParentNodes.insert(node);
       }
     }
 
-    // If there is no more parent, find nodes that are direct children of current GraphView,
-    // such that the obtained GraphView outputs list will be the same, regardless 
-    // of the evaluation order of those nodes
-    // (i.e. one of their parent is in current GraphView)
-    // TODO: this might be done simultaneously with direct parents, by removing
-    // the empty() condition, but there might be edge cases that may change
-    // the resulting inputs/outputs order depending on evaluation order (???)
-    if (nextNodesToAdd.empty()) {
-      for (auto it = nodesToAdd.begin(); it != nodesToAdd.end(); ++it) {
-        for (auto parent : (*it)->getParents()) {
-          if (mNodes.find(parent) != mNodes.end()) {
-            nextNodesToAdd.insert(*it);
-            it = nodesToAdd.erase(it);
-            break;
+    // Take the first one found (this is an arbitrary choice)
+    mRootNode = *noParentNodes.begin();
+
+    if (noParentNodes.size() > 1) {
+      // If there is more than one, order unicity cannot be garanteed!
+      orderUnicity = false;
+    }
+
+    rankedNodesToAdd.push_back(mRootNode);
+  }
+
+  nodesToRank.erase(mRootNode);
+  std::vector<NodePtr> rankedNodes;
+  rankedNodes.push_back(mRootNode);
+
+  for (size_t curNodeIdx = 0; curNodeIdx < rankedNodes.size(); ++curNodeIdx) {
+    NodePtr curNode = rankedNodes[curNodeIdx];
+
+    for (auto childs : curNode->getOrderedChildren()) {
+      for (auto child : childs) {
+        if (nodesToRank.find(child) != nodesToRank.end()) {
+          rankedNodes.push_back(child);
+          nodesToRank.erase(child);
+
+          if (nodesToAdd.find(child) != nodesToAdd.end()) {
+            rankedNodesToAdd.push_back(child);
+            nodesToAdd.erase(child);
           }
         }
-        if (it == nodesToAdd.end()) {
-          break;
-        }
       }
     }
 
-    // If no node if found, there might be remaining nodes that form an independant sub-graph
-    // In this case, additionnal inputs/outputs will be added at the end of
-    // the GraphView inputs/outputs list, in no particular order.
-    // TODO: we might try to preserve the initial inputs/ouputs relative order of those nodes
-    // if they actually comes from a GraphView, but I think that would be a far-fetched expectation
-    // from the users...
-    if (nextNodesToAdd.empty()) {
-      nodesToAdd.swap(nextNodesToAdd);
+    for (auto parent : curNode->getParents()) {
+      if (nodesToRank.find(parent) != nodesToRank.end()) {
+        rankedNodes.push_back(parent);
+        nodesToRank.erase(parent);
+
+        if (nodesToAdd.find(parent) != nodesToAdd.end()) {
+          rankedNodesToAdd.push_back(parent);
+          nodesToAdd.erase(parent);
+        }
+      }
     }
+  }
 
-    // Add selected nodes in the current GraphView, in no particular order
-    for (auto node_ptr : nextNodesToAdd) {
-      add(node_ptr, includeLearnableParam);
+  if (!nodesToAdd.empty()) {
+    // There are remaining nodes without path to the root node
+    orderUnicity = false;
+
+    while (!nodesToAdd.empty()) {
+      const auto it = nodesToAdd.begin();
+      rankedNodesToAdd.push_back(*it);
+      nodesToAdd.erase(it);
     }
   }
-  while (!nodesToAdd.empty());
+
+  for (auto node_ptr : rankedNodesToAdd) {
+    add(node_ptr, includeLearnableParam);
+  }
+
+  return orderUnicity;
 }
 
-void Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
+bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
   if (nodes.first != nullptr) {
+    mRootNode = nodes.first;
     add(nodes.first, includeLearnableParam);
   }
-  add(nodes.second, includeLearnableParam);
+  return add(nodes.second, includeLearnableParam);
 }
 
-void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
-  add(graph->getNodes(), false);
+bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
+  if (mRootNode == nullptr) {
+    mRootNode = graph->getRootNode();
+  }
+
+  return add(graph->getNodes(), false);
 }
 
 void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index a80c7a7aa..8da67e784 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -88,20 +88,32 @@ std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<s
 
 
 TEST_CASE("genRandomDAG") {
-    std::random_device rd;
-    const std::mt19937::result_type seed(rd());
+    const size_t nbTests = 100;
+    size_t nbUnicity = 0;
 
-    auto g1 = std::make_shared<GraphView>();
-    g1->add(genRandomDAG(seed, 10, 0.5));
-    auto g2 = std::make_shared<GraphView>();
-    g2->add(genRandomDAG(seed, 10, 0.5));
+    for (int test = 0; test < nbTests; ++test) {
+        std::random_device rd;
+        const std::mt19937::result_type seed(rd());
 
-    g1->save("./genRandomDAG1");
-    g2->save("./genRandomDAG2");
+        const auto g1 = std::make_shared<GraphView>("g1");
+        const bool unicity1 = g1->add(genRandomDAG(seed, 10, 0.5));
+        const auto g2 = std::make_shared<GraphView>("g2");
+        const bool unicity2 = g2->add(genRandomDAG(seed, 10, 0.5));
 
-    REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
-    REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
-    REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
+        g1->save("./genRandomDAG1");
+        g2->save("./genRandomDAG2");
+
+        REQUIRE(unicity1 == unicity2);
+
+        if (unicity1) {
+            REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
+            REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
+            REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
+            ++nbUnicity;
+        }
+    }
+
+    printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests);
 }
 
 TEST_CASE("[core/graph] GraphView(Constructor)") {
-- 
GitLab