Skip to content
Snippets Groups Projects
Commit 75c3824a authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added check of input category in GraphView::replace()

parent c0c587e1
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!143Multiple refactors
Pipeline #48735 passed
...@@ -1074,7 +1074,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -1074,7 +1074,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
// Case 2 // Case 2
if ((oldOIn.size() == 1) && (inputParents[0].first)) { if ((oldOIn.size() == 1) && (inputParents[0].first)) {
for (std::size_t i = 0; i < newOIn.size(); ++i) { for (std::size_t i = 0; i < newOIn.size(); ++i) {
inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second); // Only re-connect the same input category
if (newOIn[i].first->inputCategory(newOIn[i].second) == oldOIn[0].first->inputCategory(oldOIn[0].second)) {
inputParents[0].first -> addChild(newOIn[i].first, inputParents[0].second, newOIn[i].second);
}
} }
} else { } else {
for (std::size_t i = 0; i < oldOIn.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
......
...@@ -552,6 +552,69 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { ...@@ -552,6 +552,69 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
} }
SECTION("replace same input category 1") {
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
auto otherInput = GenericOperator("Producer", {}, 1, "other_input");
auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1");
auto myOld = GenericOperator("myOld", {InputCategory::Data}, 1, "old");
auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2");
otherInput->addChild(other1);
other1->addChild(myOld);
myOld->addChild(other2);
g->add({other1, myOld, other2});
auto myNew = GenericOperator("myNew", {InputCategory::Data, InputCategory::OptionalData, InputCategory::OptionalData}, 1, "new");
GraphView::replace({myOld}, {myNew});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2}));
REQUIRE(myNew->input(0).first == other1);
REQUIRE(myNew->input(1).first == nullptr);
REQUIRE(myNew->input(2).first == nullptr);
}
SECTION("replace same input category 2") {
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
auto otherInput = GenericOperator("Producer", {}, 1, "other_input");
auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1");
auto myOld = GenericOperator("myOld", {InputCategory::Param}, 1, "old");
auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2");
otherInput->addChild(other1);
other1->addChild(myOld, 0, 0);
myOld->addChild(other2);
g->add({other1, myOld, other2});
auto myNew = GenericOperator("myNew", {InputCategory::Data, InputCategory::Param, InputCategory::Data}, 1, "new");
GraphView::replace({myOld}, {myNew});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2}));
REQUIRE(myNew->input(0).first == nullptr);
REQUIRE(myNew->input(1).first == other1);
REQUIRE(myNew->input(2).first == nullptr);
}
SECTION("replace same input category 3") {
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
auto otherInput = GenericOperator("Producer", {}, 1, "other_input");
auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1");
auto myOld = GenericOperator("myOld", {InputCategory::Data}, 1, "old");
auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2");
otherInput->addChild(other1);
other1->addChild(myOld);
myOld->addChild(other2);
g->add({other1, myOld, other2});
auto myNew = GenericOperator("myNew", {InputCategory::Data, InputCategory::Data, InputCategory::Data}, 1, "new");
GraphView::replace({myOld}, {myNew});
REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew, other2}));
REQUIRE(myNew->input(0).first == other1);
REQUIRE(myNew->input(1).first == other1);
REQUIRE(myNew->input(2).first == other1);
}
SECTION("Change every Nodes in a GraphView") { SECTION("Change every Nodes in a GraphView") {
auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0"); auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0");
auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0"); auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment