From 75c3824a654ad2369f8489ab925e7bf01f391b8e Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 20 Jun 2024 21:32:16 +0200
Subject: [PATCH] Added check of input category in GraphView::replace()

---
 src/graph/GraphView.cpp             |  5 ++-
 unit_tests/graph/Test_GraphView.cpp | 63 +++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index fb8a79cfe..25f1a0187 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -1074,7 +1074,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
             // Case 2
             if ((oldOIn.size() == 1) && (inputParents[0].first)) {
                 for (std::size_t i = 0; i < newOIn.size(); ++i) {
-                    inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second);
+                    // Only re-connect the same input category
+                    if (newOIn[i].first->inputCategory(newOIn[i].second) == oldOIn[0].first->inputCategory(oldOIn[0].second)) {
+                      inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second);
+                    }
                 }
             } else {
                 for (std::size_t i = 0; i < oldOIn.size(); ++i) {
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 4f410a8c4..8e9f5a27e 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -552,6 +552,69 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
         REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
     }
 
+    SECTION("replace same input category 1") {
+        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
+        auto otherInput = GenericOperator("Producer", {}, 1, "other_input");
+        auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1");
+        auto myOld = GenericOperator("myOld", {InputCategory::Data}, 1, "old");
+        auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2");
+        otherInput->addChild(other1);
+        other1->addChild(myOld);
+        myOld->addChild(other2);
+        g->add({other1, myOld, other2});
+
+        auto myNew =  GenericOperator("myNew", {InputCategory::Data, InputCategory::OptionalData, InputCategory::OptionalData}, 1, "new");
+
+        GraphView::replace({myOld}, {myNew});
+
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2}));
+        REQUIRE(myNew->input(0).first == other1);
+        REQUIRE(myNew->input(1).first == nullptr);
+        REQUIRE(myNew->input(2).first == nullptr);
+    }
+
+    SECTION("replace same input category 2") {
+        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
+        auto otherInput = GenericOperator("Producer", {}, 1, "other_input");
+        auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1");
+        auto myOld = GenericOperator("myOld", {InputCategory::Param}, 1, "old");
+        auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2");
+        otherInput->addChild(other1);
+        other1->addChild(myOld, 0, 0);
+        myOld->addChild(other2);
+        g->add({other1, myOld, other2});
+
+        auto myNew =  GenericOperator("myNew", {InputCategory::Data, InputCategory::Param, InputCategory::Data}, 1, "new");
+
+        GraphView::replace({myOld}, {myNew});
+
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2}));
+        REQUIRE(myNew->input(0).first == nullptr);
+        REQUIRE(myNew->input(1).first == other1);
+        REQUIRE(myNew->input(2).first == nullptr);
+    }
+
+    SECTION("replace same input category 3") {
+        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
+        auto otherInput = GenericOperator("Producer", {}, 1, "other_input");
+        auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1");
+        auto myOld = GenericOperator("myOld", {InputCategory::Data}, 1, "old");
+        auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2");
+        otherInput->addChild(other1);
+        other1->addChild(myOld);
+        myOld->addChild(other2);
+        g->add({other1, myOld, other2});
+
+        auto myNew =  GenericOperator("myNew", {InputCategory::Data, InputCategory::Data, InputCategory::Data}, 1, "new");
+
+        GraphView::replace({myOld}, {myNew});
+
+        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2}));
+        REQUIRE(myNew->input(0).first == other1);
+        REQUIRE(myNew->input(1).first == other1);
+        REQUIRE(myNew->input(2).first == other1);
+    }
+
     SECTION("Change every Nodes in a GraphView") {
         auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0");
         auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0");
-- 
GitLab