From ad28a0e83d449307d29953f70201b2badd4b93fb Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 8 Nov 2023 10:14:31 +0000 Subject: [PATCH] [Fix] replace() member function to run python tests successfully --- include/aidge/utils/Recipies.hpp | 5 +++- src/graph/GraphView.cpp | 13 ++++++---- unit_tests/graph/Test_GraphView.cpp | 39 +++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index 894e56fae..c110c9cf8 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -12,6 +12,9 @@ #ifndef AIDGE_CORE_UTILS_RECIPIES_H_ #define AIDGE_CORE_UTILS_RECIPIES_H_ +#include <memory> +#include <set> + #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" @@ -47,7 +50,7 @@ void removeFlatten(std::set<std::shared_ptr<Node>> nodes); * @param graphView Graph view to use graph matching on, in order to apply transfomrations. */ void removeFlatten(std::shared_ptr<GraphView> graphView); - + // FUSE BN + FC || CONV -> FC || CONV /** diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 98644c5bd..fbb4b9871 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -619,14 +619,17 @@ bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const s for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } // copy output connections - for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { - auto outputPairs = copyOutputs[o]; - for (const auto& onePair : outputPairs) { - newOutputNode->addChild(onePair.first, o, onePair.second); + if (newOutputNode) { + for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { + auto outputPairs = copyOutputs[o]; + for (const auto& onePair : outputPairs) { + newOutputNode->addChild(onePair.first, o, onePair.second); + } } } + // copy input connections - if (!newNodes.empty()) { + if (!newNodes.empty() && externalInput) { for (const auto& newInputNode : newG->inputNodes()) { IOIndex_t inputId = 0; for (const auto& input : newInputNode->inputs()) { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 341c7cb7a..e23ee94a7 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -12,6 +12,7 @@ #include <cassert> #include <map> #include <memory> +#include <set> #include <string> #include <catch2/catch_test_macros.hpp> @@ -419,6 +420,44 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); } + + SECTION("Change every Nodes in a GraphView") { + auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0"); + auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0"); + auto matmul0 = GenericOperator("MatMul", 1, 2, 1, "matmul0"); + auto add0 = GenericOperator("Add", 1, 2, 1, "add0"); + auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1"); + auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1"); + auto matmul1 = GenericOperator("MatMul", 1, 2, 1, "matmul1"); + auto add1 = GenericOperator("Add", 1, 2, 1, "add1"); + + matmulWeight0 -> addChild(matmul0, 0, 1); + addBias0 -> addChild(add0, 0, 1); + matmulWeight1 -> addChild(matmul1, 0, 1); + addBias1 -> addChild(add1, 0, 1); + matmul0 -> addChild(add0, 0, 0); + add0 -> addChild(matmul1, 0, 0); + matmul1 -> addChild(add1, 0, 0); + + auto g = std::make_shared<GraphView>("TestGraph"); + g -> add({matmulWeight0, addBias0, matmulWeight1, addBias1, matmul0, add0, matmul1, add1}); + auto newMatmulWeight0 = matmulWeight0->cloneSharedOperators(); + auto newAddBias0 = addBias0->cloneSharedOperators(); + auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators(); + auto newAddBias1 = addBias1->cloneSharedOperators(); + auto fc0 = GenericOperator("FC", 1, 3, 1, "fc0"); + auto fc1 = GenericOperator("FC", 1, 3, 1, "fc1"); + + newMatmulWeight0 -> addChild(fc0, 0, 1); + newAddBias0 -> addChild(fc0, 0, 2); + newMatmulWeight1 -> addChild(fc1, 0, 1); + newAddBias1 -> addChild(fc1, 0, 2); + + GraphView::replace({matmul0, add0, matmulWeight0, addBias0}, {newMatmulWeight0, newAddBias0, fc0}); + GraphView::replace({matmul1, add1, matmulWeight1, addBias1}, {newMatmulWeight1, newAddBias1, fc1}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0})); + } } TEST_CASE("[GraphView] clone") { -- GitLab