From ad28a0e83d449307d29953f70201b2badd4b93fb Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 8 Nov 2023 10:14:31 +0000
Subject: [PATCH] [Fix] replace() member function to run python tests
 successfully

---
 include/aidge/utils/Recipies.hpp    |  5 +++-
 src/graph/GraphView.cpp             | 13 ++++++----
 unit_tests/graph/Test_GraphView.cpp | 39 +++++++++++++++++++++++++++++
 3 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp
index 894e56fae..c110c9cf8 100644
--- a/include/aidge/utils/Recipies.hpp
+++ b/include/aidge/utils/Recipies.hpp
@@ -12,6 +12,9 @@
 #ifndef AIDGE_CORE_UTILS_RECIPIES_H_
 #define AIDGE_CORE_UTILS_RECIPIES_H_
 
+#include <memory>
+#include <set>
+
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
 
@@ -47,7 +50,7 @@ void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
  * @param graphView Graph view to use graph matching on, in order to apply transfomrations.
  */
 void removeFlatten(std::shared_ptr<GraphView> graphView);
- 
+
 // FUSE BN + FC || CONV -> FC || CONV
 
 /**
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 98644c5bd..fbb4b9871 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -619,14 +619,17 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s
     for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
 
     // copy output connections
-    for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
-        auto outputPairs = copyOutputs[o];
-        for (const auto& onePair : outputPairs) {
-            newOutputNode->addChild(onePair.first, o, onePair.second);
+    if (newOutputNode) {
+        for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) {
+            auto outputPairs = copyOutputs[o];
+            for (const auto& onePair : outputPairs) {
+                newOutputNode->addChild(onePair.first, o, onePair.second);
+            }
         }
     }
+
     // copy input connections
-    if (!newNodes.empty()) {
+    if (!newNodes.empty() && externalInput) {
         for (const auto& newInputNode : newG->inputNodes()) {
             IOIndex_t inputId = 0;
             for (const auto& input : newInputNode->inputs()) {
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 341c7cb7a..e23ee94a7 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -12,6 +12,7 @@
 #include <cassert>
 #include <map>
 #include <memory>
+#include <set>
 #include <string>
 
 #include <catch2/catch_test_macros.hpp>
@@ -419,6 +420,44 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
 
         REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
     }
+
+    SECTION("Change every Nodes in a GraphView") {
+        auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0");
+        auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0");
+        auto matmul0 = GenericOperator("MatMul", 1, 2, 1, "matmul0");
+        auto add0 = GenericOperator("Add", 1, 2, 1, "add0");
+        auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1");
+        auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1");
+        auto matmul1 = GenericOperator("MatMul", 1, 2, 1, "matmul1");
+        auto add1 = GenericOperator("Add", 1, 2, 1, "add1");
+
+        matmulWeight0 -> addChild(matmul0, 0, 1);
+        addBias0 -> addChild(add0, 0, 1);
+        matmulWeight1 -> addChild(matmul1, 0, 1);
+        addBias1 -> addChild(add1, 0, 1);
+        matmul0 -> addChild(add0, 0, 0);
+        add0 -> addChild(matmul1, 0, 0);
+        matmul1 -> addChild(add1, 0, 0);
+
+        auto g = std::make_shared<GraphView>("TestGraph");
+        g -> add({matmulWeight0, addBias0, matmulWeight1, addBias1, matmul0, add0, matmul1, add1});
+        auto newMatmulWeight0 = matmulWeight0->cloneSharedOperators();
+        auto newAddBias0 = addBias0->cloneSharedOperators();
+        auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators();
+        auto newAddBias1 = addBias1->cloneSharedOperators();
+        auto fc0 = GenericOperator("FC", 1, 3, 1, "fc0");
+        auto fc1 = GenericOperator("FC", 1, 3, 1, "fc1");
+
+        newMatmulWeight0 -> addChild(fc0, 0, 1);
+        newAddBias0 -> addChild(fc0, 0, 2);
+        newMatmulWeight1 -> addChild(fc1, 0, 1);
+        newAddBias1 -> addChild(fc1, 0, 2);
+
+        GraphView::replace({matmul0, add0, matmulWeight0, addBias0}, {newMatmulWeight0, newAddBias0, fc0});
+        GraphView::replace({matmul1, add1, matmulWeight1, addBias1}, {newMatmulWeight1, newAddBias1, fc1});
+
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0}));
+    }
 }
 
 TEST_CASE("[GraphView] clone") {
-- 
GitLab