From 75c3824a654ad2369f8489ab925e7bf01f391b8e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 20 Jun 2024 21:32:16 +0200 Subject: [PATCH] Added check of input category in GraphView::replace() --- src/graph/GraphView.cpp | 5 ++- unit_tests/graph/Test_GraphView.cpp | 63 +++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index fb8a79cfe..25f1a0187 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -1074,7 +1074,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const // Case 2 if ((oldOIn.size() == 1) && (inputParents[0].first)) { for (std::size_t i = 0; i < newOIn.size(); ++i) { - inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second); + // Only re-connect the same input category + if (newOIn[i].first->inputCategory(newOIn[i].second) == oldOIn[0].first->inputCategory(oldOIn[0].second)) { + inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second); + } } } else { for (std::size_t i = 0; i < oldOIn.size(); ++i) { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 4f410a8c4..8e9f5a27e 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -552,6 +552,69 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); } + SECTION("replace same input category 1") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph"); + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld = GenericOperator("myOld", {InputCategory::Data}, 1, "old"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + otherInput->addChild(other1); + other1->addChild(myOld); + myOld->addChild(other2); + g->add({other1, myOld, other2}); + + auto myNew = GenericOperator("myNew", {InputCategory::Data, InputCategory::OptionalData, InputCategory::OptionalData}, 1, "new"); + + GraphView::replace({myOld}, {myNew}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2})); + REQUIRE(myNew->input(0).first == other1); + REQUIRE(myNew->input(1).first == nullptr); + REQUIRE(myNew->input(2).first == nullptr); + } + + SECTION("replace same input category 2") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph"); + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld = GenericOperator("myOld", {InputCategory::Param}, 1, "old"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + otherInput->addChild(other1); + other1->addChild(myOld, 0, 0); + myOld->addChild(other2); + g->add({other1, myOld, other2}); + + auto myNew = GenericOperator("myNew", {InputCategory::Data, InputCategory::Param, InputCategory::Data}, 1, "new"); + + GraphView::replace({myOld}, {myNew}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2})); + REQUIRE(myNew->input(0).first == nullptr); + REQUIRE(myNew->input(1).first == other1); + REQUIRE(myNew->input(2).first == nullptr); + } + + SECTION("replace same input category 3") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph"); + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld = GenericOperator("myOld", {InputCategory::Data}, 1, "old"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + otherInput->addChild(other1); + other1->addChild(myOld); + myOld->addChild(other2); + g->add({other1, myOld, other2}); + + auto myNew = GenericOperator("myNew", {InputCategory::Data, InputCategory::Data, InputCategory::Data}, 1, "new"); + + GraphView::replace({myOld}, {myNew}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2})); + REQUIRE(myNew->input(0).first == other1); + REQUIRE(myNew->input(1).first == other1); + REQUIRE(myNew->input(2).first == other1); + } + SECTION("Change every Nodes in a GraphView") { auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0"); auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0"); -- GitLab