From 7299fadf7f0165cdcca1aad5bed399f283740515 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 7 Nov 2023 13:59:38 +0000
Subject: [PATCH] [Upd] recipies use replace() instead of replaceWith()

---
 src/recipies/FuseBatchNorm.cpp      |  7 +++----
 src/recipies/FuseMulAdd.cpp         | 28 ++++++++++++++--------------
 src/recipies/RemoveFlatten.cpp      |  6 ++----
 unit_tests/graph/Test_GraphView.cpp | 18 ++++++++++--------
 4 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp
index e5e59582a..4b2f7a811 100644
--- a/src/recipies/FuseBatchNorm.cpp
+++ b/src/recipies/FuseBatchNorm.cpp
@@ -116,15 +116,14 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){
         bias->set<float>(output, biasValue);
 
     }
-    auto g = std::make_shared<GraphView>();
-    g->add(std::set<std::shared_ptr<Node>>({
+
+    GraphView::replace(std::set<std::shared_ptr<Node>>({
         batchnorm,
         batchnorm->input(1).first,
         batchnorm->input(2).first,
         batchnorm->input(3).first,
         batchnorm->input(4).first
-    }));
-    g->replaceWith({});
+        }), {});
 
 }
 
diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp
index 1de79890f..3c895c4b1 100644
--- a/src/recipies/FuseMulAdd.cpp
+++ b/src/recipies/FuseMulAdd.cpp
@@ -47,8 +47,11 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
 
     // Step 1 : Create FC
     // Fetch the output dimension throught the bias size
-    auto producer_add_bias = add->input(1);
-    Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0);
+    std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr;
+
+    Tensor& bias_tensor = bias->getOperator()->output(0);
+
+    std::shared_ptr<Node> weight = (matmul->getParent(1)) ? matmul->getParent(1)->cloneSharedOperators() : nullptr;
 
     // Instanciate FC
     //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
@@ -56,25 +59,22 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
 
     // Step 2 : Branch existing producers & create the others
     // link weights & bias
-    if (matmul->getParent(1)==nullptr) {
-        matmul->getParent(0)->addChild(fc, 0, 1);
-        printf("MatMul out[1] == nullptr !\n");
-    } else {
-        printf("MatMul out[1] != nullptr !\n");
-        if (matmul->getParent(0)!=nullptr)
-            matmul->getParent(0)->addChild(fc, 0, 0);
-        matmul->input(1).first->addChild(fc, 0, 1);
+    if (weight) {
+        weight->addChild(fc, 0, 1);
+    }
+    if (bias) {
+        bias->addChild(fc, 0, 2);
     }
-    (producer_add_bias.first)->addChild(fc,0,2);
 
 
     // Step 3 : Update all graphviews that contains at least one node to replace
         // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
         // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
         // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ?
-    auto nodeToReplace = std::make_shared<GraphView>();
-    nodeToReplace->add(nodes, false);
-    nodeToReplace->replaceWith({fc});
+    // auto nodeToReplace = std::make_shared<GraphView>();
+    // nodeToReplace->add(nodes, false);
+    // nodeToReplace->replaceWith({fc});
+    GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, {fc, weight, bias});
 
 }
 
diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp
index 9096c107b..fdfdbfd4a 100644
--- a/src/recipies/RemoveFlatten.cpp
+++ b/src/recipies/RemoveFlatten.cpp
@@ -30,10 +30,8 @@ namespace Aidge {
                 flatten = element;
             }
         }
-        auto g = std::make_shared<GraphView>();
-        // TODO : avoid using replace_with and use a remove method instead
-        g->add(std::set<std::shared_ptr<Node>>({flatten}));
-        g->replaceWith({});
+
+        GraphView::replace({flatten}, {});
     }
 
     void removeFlatten(std::shared_ptr<GraphView> graphView){
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 387d34ad8..341c7cb7a 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -277,7 +277,7 @@ TEST_CASE("Graph Forward dims", "[GraphView]") {
     }
 }
 
-TEST_CASE("[core/graph] GraphView(replaceWith)") {
+TEST_CASE("[core/graph] GraphView(replaceWith)", "[replaceWith]") {
     SECTION("replace small pattern") {
         // create original graph
         std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
@@ -298,19 +298,21 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
         REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add}));
 
         // create graph to replace
-        std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>();
-        nodeToReplace->add({matmul, add}, false);
+        std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>("NodesToReplace");
+        nodeToReplace->add({matmul, add}, true);
 
         // create replacing graph
         std::shared_ptr<Node> newNode = GenericOperator("FC", 1, 3, 1, "fc");
-        other1->addChild(newNode);
-        matmulWeight->addChild(newNode, 0, 1);
-        addBias->addChild(newNode, 0, 2);
+        // other1->addChild(newNode);
+       auto newMatmulWeight = matmulWeight->cloneSharedOperators();
+        newMatmulWeight->addChild(newNode, 0, 1);
+        auto newAddBias = addBias->cloneSharedOperators();
+        newAddBias->addChild(newNode, 0, 2);
 
         // replace
-        nodeToReplace->replaceWith({newNode});
+        nodeToReplace->replaceWith({newNode, newMatmulWeight, newAddBias});
 
-        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, newNode}));
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, newNode}));
     }
     SECTION("replace with nothing") {
         std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
-- 
GitLab