From 7299fadf7f0165cdcca1aad5bed399f283740515 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 7 Nov 2023 13:59:38 +0000 Subject: [PATCH] [Upd] recipies use replace() instead of replaceWith() --- src/recipies/FuseBatchNorm.cpp | 7 +++---- src/recipies/FuseMulAdd.cpp | 28 ++++++++++++++-------------- src/recipies/RemoveFlatten.cpp | 6 ++---- unit_tests/graph/Test_GraphView.cpp | 18 ++++++++++-------- 4 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index e5e59582a..4b2f7a811 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -116,15 +116,14 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ bias->set<float>(output, biasValue); } - auto g = std::make_shared<GraphView>(); - g->add(std::set<std::shared_ptr<Node>>({ + + GraphView::replace(std::set<std::shared_ptr<Node>>({ batchnorm, batchnorm->input(1).first, batchnorm->input(2).first, batchnorm->input(3).first, batchnorm->input(4).first - })); - g->replaceWith({}); + }), {}); } diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 1de79890f..3c895c4b1 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -47,8 +47,11 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Step 1 : Create FC // Fetch the output dimension throught the bias size - auto producer_add_bias = add->input(1); - Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0); + std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr; + + Tensor& bias_tensor = bias->getOperator()->output(0); + + std::shared_ptr<Node> weight = (matmul->getParent(1)) ? matmul->getParent(1)->cloneSharedOperators() : nullptr; // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); @@ -56,25 +59,22 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Step 2 : Branch existing producers & create the others // link weights & bias - if (matmul->getParent(1)==nullptr) { - matmul->getParent(0)->addChild(fc, 0, 1); - printf("MatMul out[1] == nullptr !\n"); - } else { - printf("MatMul out[1] != nullptr !\n"); - if (matmul->getParent(0)!=nullptr) - matmul->getParent(0)->addChild(fc, 0, 0); - matmul->input(1).first->addChild(fc, 0, 1); + if (weight) { + weight->addChild(fc, 0, 1); + } + if (bias) { + bias->addChild(fc, 0, 2); } - (producer_add_bias.first)->addChild(fc,0,2); // Step 3 : Update all graphviews that contains at least one node to replace // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ? - auto nodeToReplace = std::make_shared<GraphView>(); - nodeToReplace->add(nodes, false); - nodeToReplace->replaceWith({fc}); + // auto nodeToReplace = std::make_shared<GraphView>(); + // nodeToReplace->add(nodes, false); + // nodeToReplace->replaceWith({fc}); + GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, {fc, weight, bias}); } diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index 9096c107b..fdfdbfd4a 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -30,10 +30,8 @@ namespace Aidge { flatten = element; } } - auto g = std::make_shared<GraphView>(); - // TODO : avoid using replace_with and use a remove method instead - g->add(std::set<std::shared_ptr<Node>>({flatten})); - g->replaceWith({}); + + GraphView::replace({flatten}, {}); } void removeFlatten(std::shared_ptr<GraphView> graphView){ diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 387d34ad8..341c7cb7a 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -277,7 +277,7 @@ TEST_CASE("Graph Forward dims", "[GraphView]") { } } -TEST_CASE("[core/graph] GraphView(replaceWith)") { +TEST_CASE("[core/graph] GraphView(replaceWith)", "[replaceWith]") { SECTION("replace small pattern") { // create original graph std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); @@ -298,19 +298,21 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add})); // create graph to replace - std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>(); - nodeToReplace->add({matmul, add}, false); + std::shared_ptr<GraphView> nodeToReplace = std::make_shared<GraphView>("NodesToReplace"); + nodeToReplace->add({matmul, add}, true); // create replacing graph std::shared_ptr<Node> newNode = GenericOperator("FC", 1, 3, 1, "fc"); - other1->addChild(newNode); - matmulWeight->addChild(newNode, 0, 1); - addBias->addChild(newNode, 0, 2); + // other1->addChild(newNode); + auto newMatmulWeight = matmulWeight->cloneSharedOperators(); + newMatmulWeight->addChild(newNode, 0, 1); + auto newAddBias = addBias->cloneSharedOperators(); + newAddBias->addChild(newNode, 0, 2); // replace - nodeToReplace->replaceWith({newNode}); + nodeToReplace->replaceWith({newNode, newMatmulWeight, newAddBias}); - REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, newNode})); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, newNode})); } SECTION("replace with nothing") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); -- GitLab