From b92b7f38a856f37991198b266754a094e708e34f Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Mon, 27 Nov 2023 17:15:10 +0100
Subject: [PATCH] Improved graph visualization

---
 include/aidge/graph/GraphView.hpp   | 13 ++++++
 src/graph/GraphView.cpp             | 60 ++++++++++++++++++++------
 unit_tests/graph/Test_GraphView.cpp | 67 +++++++++++++++++++++++------
 3 files changed, 116 insertions(+), 24 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 62f6ac11c..5462935be 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -133,6 +133,9 @@ public:
     void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
     void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
 
+    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; };
+    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; };
+
     /**
      * @brief List outside data input connections of the GraphView.
      * Data inputs exclude inputs expecting parameters (weights or bias).
@@ -255,6 +258,7 @@ public:
      * in the GraphView automatically. Default: true.
      */
     void add(NodePtr otherNode, bool includeLearnableParam = true);
+
     /**
      * @brief Include a set of Nodes to the current GraphView object.
      * @param otherNodes
@@ -263,6 +267,15 @@ public:
     void add(std::set<NodePtr> otherNodes,
              bool includeLearnableParam = true);
 
+    /**
+     * @brief Include a set of Nodes to the current GraphView object.
+     * The second element in the otherNodes pair is the start node.
+     * @param otherNodes
+     * @param includeLearnableParam
+     */
+    void add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
+             bool includeLearnableParam = true);
+
     /**
      * @brief Include every Node inside another GraphView to the current
      * GraphView.
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 2714484eb..7b21cc889 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -79,23 +79,48 @@ void Aidge::GraphView::save(std::string path, bool verbose) const {
                     givenName.c_str());
     }
     // Write every link
-    std::size_t emptyInputCounter = 0;
     for (const std::shared_ptr<Node> &node_ptr : mNodes) {
-        for (const std::shared_ptr<Node> &pa_ptr : node_ptr->getParents()) {
-        if ((pa_ptr == nullptr) || !inView(pa_ptr)) {
-            std::fprintf(fp, "input%zu((in - %zu))-->%s\n", emptyInputCounter,
-                        emptyInputCounter, namePtrTable[node_ptr].c_str());
-            ++emptyInputCounter;
-        } else {
-            std::fprintf(fp, "%s-->%s\n", namePtrTable[pa_ptr].c_str(),
-                        namePtrTable[node_ptr].c_str());
-        }
+      IOIndex_t outputIdx = 0;
+      for (auto childs : node_ptr->getOrderedChildren()) {
+        for (auto child : childs) {
+          if (child) {
+            IOIndex_t inputIdx = 0;
+            for (auto pa_ptr : child->getParents()) {
+              if (pa_ptr == node_ptr) {
+                std::fprintf(fp, "%s-->|%u..%u|%s\n", namePtrTable[node_ptr].c_str(),
+                            outputIdx, inputIdx, namePtrTable[child].c_str());
+                break;
+              }
+              ++inputIdx;
+            }
+          }
         }
+        ++outputIdx;
+      }
+    }
+
+    size_t inputIdx = 0;
+    for (auto input : mInputNodes) {
+      std::fprintf(fp, "input%lu((in#%lu)):::inputCls-->|..%u|%s\n", inputIdx, inputIdx,
+                  input.second, namePtrTable[input.first].c_str());
+      ++inputIdx;
     }
+
+    size_t outputIdx = 0;
+    for (auto output : mOutputNodes) {
+      std::fprintf(fp, "%s-->|%u..|output%lu((out#%lu)):::outputCls\n",
+                   namePtrTable[output.first].c_str(), output.second,
+                   outputIdx, outputIdx);
+      ++outputIdx;
+    }
+
+    std::fprintf(fp, "classDef inputCls fill:#afa\n");
+    std::fprintf(fp, "classDef outputCls fill:#ffa\n");
+
     if (verbose) {
-        for (const auto &c : typeCounter) {
+      for (const auto &c : typeCounter) {
         std::printf("%s - %zu\n", c.first.c_str(), c.second);
-        }
+      }
     }
 
     std::fprintf(fp, "\n");
@@ -447,6 +472,13 @@ void Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl
   while (!nodesToAdd.empty());
 }
 
+void Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool includeLearnableParam) {
+  if (nodes.first != nullptr) {
+    add(nodes.first, includeLearnableParam);
+  }
+  add(nodes.second, includeLearnableParam);
+}
+
 void Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
   add(graph->getNodes(), false);
 }
@@ -834,6 +866,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
       const auto val = std::make_pair(newNode, inputIdx);
       if (std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end()) {
         newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
+        newInputsInsertionPoint = std::next(newInputsInsertionPoint);
       }
     }
     ++inputIdx;
@@ -902,6 +935,7 @@ void Aidge::GraphView::updateInputsOutputsNew(std::shared_ptr<Node> newNode) {
       // Output may be already be present (see addChild() with a node already in GraphView)
       if (std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end()) {
         newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
+        newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
       }
     }
     ++outputIdx;
@@ -940,6 +974,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
             const auto val = std::make_pair(ch_ptr, inputIdx);
             AIDGE_INTERNAL_ASSERT(std::find(mInputNodes.begin(), mInputNodes.end(), val) == mInputNodes.end());
             newInputsInsertionPoint = mInputNodes.insert(newInputsInsertionPoint, val);
+            newInputsInsertionPoint = std::next(newInputsInsertionPoint);
           }
           ++inputIdx;
         }
@@ -986,6 +1021,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
         const auto val = std::make_pair(parent, outputIdx);
         AIDGE_INTERNAL_ASSERT(std::find(mOutputNodes.begin(), mOutputNodes.end(), val) == mOutputNodes.end());
         newOutputsInsertionPoint = mOutputNodes.insert(newOutputsInsertionPoint, val);
+        newOutputsInsertionPoint = std::next(newOutputsInsertionPoint);
       }
       ++outputIdx;
     }
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 75f9a47fe..a80c7a7aa 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -15,6 +15,8 @@
 #include <set>
 #include <string>
 #include <random>
+#include <algorithm>
+#include <utility>
 
 #include <catch2/catch_test_macros.hpp>
 
@@ -27,16 +29,28 @@
 
 using namespace Aidge;
 
-std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) {
-    std::random_device rd;
-    std::mt19937 gen(rd());
-    std::binomial_distribution<> dIn(maxIn, avgIn/maxIn);
-    std::binomial_distribution<> dOut(maxOut, avgOut/maxOut);
+std::pair<NodePtr, std::set<NodePtr>> genRandomDAG(std::mt19937::result_type seed, size_t nbNodes, float density = 0.5, size_t maxIn = 5, float avgIn = 1.5, size_t maxOut = 2, float avgOut = 1.1) {
+    std::mt19937 gen(seed);
+    std::binomial_distribution<> dIn(maxIn - 1, avgIn/maxIn);
+    std::binomial_distribution<> dOut(maxOut - 1, avgOut/maxOut);
     std::binomial_distribution<> dLink(1, density);
 
-    std::vector<NodePtr> nodes;
+    std::vector<std::pair<int, int>> nbIOs;
     for (size_t i = 0; i < nbNodes; ++i) {
-        nodes.push_back(GenericOperator("Fictive", dIn(gen), dIn(gen), dOut(gen)));
+        const auto nbIn = 1 + dIn(gen);
+        nbIOs.push_back(std::make_pair(nbIn, 1 + dOut(gen)));
+    }
+
+    std::vector<int> nodesSeq(nbNodes);
+    std::iota(nodesSeq.begin(), nodesSeq.end(), 0);
+    // Don't use gen or seed here, must be different each time!
+    std::shuffle(nodesSeq.begin(), nodesSeq.end(), std::default_random_engine(std::random_device{}()));
+
+    std::vector<NodePtr> nodes(nbNodes, nullptr);
+    for (auto idx : nodesSeq) {
+        const std::string type = "Fictive";
+        const std::string name = type + std::to_string(idx);
+        nodes[idx] = GenericOperator(type.c_str(), nbIOs[idx].first, nbIOs[idx].first, nbIOs[idx].second, name.c_str());
     }
 
     for (size_t i = 0; i < nbNodes; ++i) {
@@ -45,20 +59,49 @@ std::set<NodePtr> genRandomDAG(size_t nbNodes, float density = 0.5, size_t maxIn
                 for (size_t inId = 0; inId < nodes[j]->nbInputs(); ++inId) {
                     if (dLink(gen)) {
                         nodes[i]->addChild(nodes[j], outId, inId);
+                        break;
                     }
                 }
             }
         }
     }
-    return std::set<NodePtr>(nodes.begin(), nodes.end());
+    return std::make_pair(nodes[0], std::set<NodePtr>(nodes.begin(), nodes.end()));
+}
+
+std::set<std::string> nodePtrToName(const std::set<NodePtr>& nodes) {
+    std::set<std::string> nodesName;
+    std::transform(nodes.begin(), nodes.end(), std::inserter(nodesName, nodesName.begin()),
+        [](const NodePtr& node) {
+            return node->name();
+        });
+    return nodesName;
+}
+
+std::vector<std::pair<std::string, IOIndex_t>> nodePtrToName(const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes) {
+    std::vector<std::pair<std::string, IOIndex_t>> nodesName;
+    std::transform(nodes.begin(), nodes.end(), std::back_inserter(nodesName),
+        [](const std::pair<NodePtr, IOIndex_t>& node) {
+            return std::make_pair(node.first->name(), node.second);
+        });
+    return nodesName;
 }
 
 
 TEST_CASE("genRandomDAG") {
-    auto g = std::make_shared<GraphView>();
-    g->add(genRandomDAG(10));
-    REQUIRE(g->getNodes().size() == 10);
-    g->save("./genRandomDAG");
+    std::random_device rd;
+    const std::mt19937::result_type seed(rd());
+
+    auto g1 = std::make_shared<GraphView>();
+    g1->add(genRandomDAG(seed, 10, 0.5));
+    auto g2 = std::make_shared<GraphView>();
+    g2->add(genRandomDAG(seed, 10, 0.5));
+
+    g1->save("./genRandomDAG1");
+    g2->save("./genRandomDAG2");
+
+    REQUIRE(nodePtrToName(g1->getNodes()) == nodePtrToName(g2->getNodes()));
+    REQUIRE(nodePtrToName(g1->getOrderedInputs()) == nodePtrToName(g2->getOrderedInputs()));
+    REQUIRE(nodePtrToName(g1->getOrderedOutputs()) == nodePtrToName(g2->getOrderedOutputs()));
 }
 
 TEST_CASE("[core/graph] GraphView(Constructor)") {
-- 
GitLab