From 080743a94fb11f6331751176d7c59aa1fcf38e9f Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 27 Oct 2023 16:06:40 +0000
Subject: [PATCH] [Upd] replace() instead of replaceWith() in GraphView

---
 include/aidge/graph/GraphView.hpp   |  36 +++++++---
 src/graph/GraphView.cpp             | 101 ++++++++++++++++++++++++++++
 unit_tests/graph/Test_GraphView.cpp |  64 ++++++++++++++++++
 3 files changed, 190 insertions(+), 11 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 89ba14849..404b0fd02 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -322,17 +322,17 @@ public:
 
     /**
      * @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
-     * 
+     *
      * @param childNode Node that gets a new parent.
      * @param newParentNode Inserted Node.
      * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
      * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
      * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
      */
-    void insertParent(NodePtr childNode, 
-                        NodePtr newParentNode, 
-                        IOIndex_t childInputTensorIdx, 
-                        IOIndex_t newParentInputTensorIdx, 
+    void insertParent(NodePtr childNode,
+                        NodePtr newParentNode,
+                        IOIndex_t childInputTensorIdx,
+                        IOIndex_t newParentInputTensorIdx,
                         IOIndex_t newParentOutputTensorIdx);
 
     /**
@@ -342,6 +342,20 @@ public:
      * @return false
      */
     bool replaceWith(std::set<NodePtr> newNodes);
+
+    /**
+     * @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible.
+     * Both sets should include all the necessary Producers.
+     * @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing
+     * them will not be affected by the replacement. The oldNodes set should have only one input/output
+     * Node for automatic connections of newNodes set.
+     * @param oldNodes actual set of shared_ptr<Node> to replace.
+     * @param newNodes new set of shared_ptr<Node>.
+     * @return true
+     * @return false
+     */
+    bool replace(std::set<NodePtr>& oldNodes, std::set<NodePtr>& newNodes);
+
     void updateInputNodes();
     /**
      * @brief Process from zero the set of output Nodes.
@@ -379,6 +393,12 @@ public:
      */
     std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const;
 
+    /**
+     * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
+     * @return IOIndex_t
+     */
+    IOIndex_t getNbFreeDataInputs() const;
+
 private:
 ///////////////////////////////////////////////////////
 //        TENSOR MANAGEMENT
@@ -390,12 +410,6 @@ private:
      */
     IOIndex_t getNbDataInputs() const;
 
-    /**
-     * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
-     * @return IOIndex_t
-     */
-    IOIndex_t getNbFreeDataInputs() const;
-
     /**
      * @brief Update the set of inputNodes with a new Node, checking if it can be
      * added and removing any Node not part of mInputNode anymore.
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 8f8f51c89..9b048b126 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -17,6 +17,7 @@
 #include "aidge/utils/Types.h"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
 
 ///////////////////////////////////////////////////////
 //        FUNCTIONAL DESCRIPTION
@@ -594,6 +595,106 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
   return replacable;
 }
 
+bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidge::NodePtr>& newNodes) {
+    for (const auto& node : oldNodes) {
+        if (mNodes.find(node) == mNodes.end()) {
+            AIDGE_INTERNAL_ASSERT("GraphView asked to replace a Node it does not contain.");
+        }
+    }
+    // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes)
+    // How to distinguish it from data input?
+    // TODO: Parameter Tensors could be identified with their dimensions
+    // TODO: Take GraphView as input parameters since new Nodes should be connected whatever.
+    // It also avoids specifying each producer since they are automatically included
+
+    auto oldG = std::make_shared<GraphView>();
+    oldG->add(oldNodes, false);
+    auto newG = std::make_shared<GraphView>();
+    newG->add(newNodes, false);
+
+    if ((oldG->inputNodes().size() != 1) || (oldG->outputNodes().size() != 1)) {
+        return false;
+    }
+    if (!(newNodes.empty()) && ((newG->inputNodes().size() != 1) ||
+                                (newG->outputNodes().size() != 1))) {
+        return false;
+    }
+
+    std::shared_ptr<Node> previousInputNode = (*(oldG->inputNodes()).begin());
+    std::shared_ptr<Node> previousOutputNode = (*(oldG->outputNodes()).begin());
+
+    // find Node to link to new input Node
+    //compute number of input for previousInputNode not in oldNodes set
+    std::size_t nbExternalInputs = 0;
+    std::shared_ptr<Node> externalInput = nullptr;
+    IOIndex_t externalInputId = gk_IODefaultIndex;
+    for (const auto& input : previousInputNode->inputs()) {
+        if (oldNodes.find(input.first) == oldNodes.end()) {
+            nbExternalInputs++;
+            externalInput = input.first;
+            externalInputId = input.second;
+        }
+    }
+    if (nbExternalInputs > 1) {
+        AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set");
+    }
+    if (previousOutputNode->nbOutputs() != 1) {
+        return false;
+    }
+
+    // find Node to replicate output connections
+    std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin());
+
+    auto copyOutputs = previousOutputNode->outputs();
+    // manage Views for newNodes
+    // only keep common views to each node for the new set
+    std::set<std::shared_ptr<GraphView>> commonGraphViews =  (*oldNodes.begin())->views();
+    for (const auto& nodePtr : oldNodes) {
+      const auto nodeView = nodePtr->views();
+      std::set<std::shared_ptr<GraphView>> intersection;
+      std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(),
+                          nodeView.begin(), nodeView.end(),
+                          std::inserter(intersection, intersection.begin()));
+      commonGraphViews = intersection;
+    }
+
+    // clean Nodes to replace
+    // Do not include common Nodes to avoid cleaning Producers linked to newNodes
+    std::set<std::shared_ptr<Node>> nodesToClean;
+    std::set_difference(oldNodes.begin(), oldNodes.end(),
+                          newNodes.begin(), newNodes.end(),
+                          std::inserter(nodesToClean, nodesToClean.begin()));
+    for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); }
+
+    // copy output connections
+    for (IOIndex_t o = 0; o < previousOutputNode->nbOutputs(); ++o) {
+        auto outputPairs = copyOutputs[o];
+        for (const auto& onePair : outputPairs) {
+            newOutputNode->addChild(onePair.first, o, onePair.second);
+        }
+    }
+    // copy input connections
+    if (!newNodes.empty()) {
+        std::shared_ptr<Node> newInputNode = (*(newG->inputNodes()).begin());
+        if (newInputNode->getNbFreeDataInputs() > 1) {
+            return false;
+        }
+        // one non-connected input in newNodes set
+        externalInput->addChild(newInputNode, externalInputId, newInputNode->getFirstFreeDataInput());
+    }
+
+    // insert new Nodes in the right GraphViews
+    for (auto& graphPtr : commonGraphViews) {
+        graphPtr->add(newNodes, false);
+        if (newNodes.empty()) {
+            graphPtr->updateInputNodes();
+            graphPtr->updateOutputNodes();
+        }
+    }
+    return true;
+}
+
+
 void Aidge::GraphView::updateInputNodes() {
   mInputNodes.clear();
   for (const std::shared_ptr<Node>& go_ptr : mNodes) {
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 9f0143646..4390dbe11 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -332,6 +332,70 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") {
     }
 }
 
+TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
+    SECTION("replace small pattern") {
+        // create original graph
+        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
+        auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
+        auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w");
+        auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b");
+        auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
+        auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
+        auto matmul = GenericOperator("MatMul", 1, 2, 1, "matmul");
+        auto add = GenericOperator("Add", 1, 2, 1, "add");
+        otherInput->addChild(other1);
+        other1->addChild(matmul);
+        matmul->addChild(add);
+        add->addChild(other2);
+        matmulWeight->addChild(matmul, 0, 1);
+        addBias->addChild(add, 0, 1);
+        g->add({other1, matmul, add, other2});
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add}));
+
+        // create graph to replace
+        std::set<std::shared_ptr<Node>> nodeToReplace = std::set<std::shared_ptr<Node>>({matmulWeight, addBias, matmul, add});
+
+        // create replacing graph
+        std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 3, 1, "fc");
+        auto newMatmulWeight = matmulWeight->cloneSharedOperators();
+        newMatmulWeight->addChild(myFC, 0, 1);
+        auto newAddBias = addBias->cloneSharedOperators();
+        newAddBias->addChild(myFC, 0, 2);
+        std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias});
+
+        // replace
+        g->replace(nodeToReplace, newNodes);
+
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC}));
+        REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias)));
+    }
+    SECTION("replace with nothing") {
+        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
+        auto r1 = GenericOperator("relu", 0, 0, 1);
+        auto r2 = GenericOperator("relu", 1, 1, 1);
+        auto r3 = GenericOperator("relu", 1, 1, 1);
+        auto r4 = GenericOperator("relu", 1, 1, 0);
+        r1->addChild(r2);
+        r2->addChild(r3);
+        r3->addChild(r4);
+        g->add({r1, r2, r3, r4});
+        auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3});
+        auto newNodes = std::set<std::shared_ptr<Node>>({});
+        g->replace(nodesToReplace, newNodes);
+
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
+        REQUIRE((r1->output(0))[0].first == r4);
+    }
+    // SECTION("replace for tiling") {
+    //     std::shared_ptr<GraphView> g = std::make_shared<GraphView>();
+    //     auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
+    //     auto other1 = GenericOperator("Other", 1, 1, 1, "other1");
+    //     auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv");
+    //     auto other2 = GenericOperator("Other", 1, 1, 1, "other2");
+    //     otherInput->addChild()
+    // }
+}
+
 TEST_CASE("[GraphView] clone") {
     auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
     auto conv1 = Conv(3, 32, {3, 3}, "conv1");
-- 
GitLab