diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 1c04a320d85a833cc3c0b666390edc7a8648214b..b68dfd035921a1dce4d12b9071a8df194e2ffdd5 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -21,7 +21,7 @@ namespace py = pybind11; namespace Aidge { -void init_Recipes(py::module &m) +void init_Recipes(py::module &m) { @@ -71,9 +71,10 @@ void init_Recipes(py::module &m) )mydelimiter"); m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( - Recipe to remove a flatten operator. + Recipe to remove a Flatten operator if it is followed by a FC or a MatMul. + The recipe can remove multiple Flatten operator if they are one after the other. - :param graph_view: Graph view on which we want to apply the recipe + :param graph_view: Graph view on which we want to apply the recipe. :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index 22be1347aa7ef108f593d3aabe3ff6d75c9312b1..4a62019a7aa044ebcf2089d91f3ba097d85218e7 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -642,6 +642,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe for (const auto& output : outputs) { for (const auto& node : output) { + if (!node.first) { + continue; + } + if (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx) { if (mGraph->inView(node.first) && !it->graph->inView(node.first)) { found = true; @@ -663,6 +667,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe : it->startNode->inputs(); for (const auto& input : inputs) { + if (!input.first) { + continue; + } + if (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx) { if (mGraph->inView(input.first) && !it->graph->inView(input.first)) { found = true; @@ -741,6 +749,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe } for (const auto& node : output) { + if (!node.first) { + continue; + } + if ((type.empty() || node.first->type() == type) && (lambda.empty() || mLambda.at(lambda)(node.first)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx)) @@ -775,6 +787,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingRe : it->startNode->inputs(); for (const auto& input : inputs) { + if (!input.first) { + continue; + } + if ((type.empty() || input.first->type() == type) && (lambda.empty() || mLambda.at(lambda)(input.first)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx)) diff --git a/src/recipes/RemoveFlatten.cpp b/src/recipes/RemoveFlatten.cpp index 8c1bf1bcf0bf79fda275867ff6430d5a937da172..bf80ab51749953a5b72d0e01f186265fdbb72e81 100644 --- a/src/recipes/RemoveFlatten.cpp +++ b/src/recipes/RemoveFlatten.cpp @@ -17,38 +17,20 @@ //Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" +// #include "aidge/graphRegex/GraphRegex.hpp" +#include "aidge/graph/Matching.hpp" namespace Aidge { - void removeFlatten(std::shared_ptr<Node> flatten) { - GraphView::replace({flatten}, {}); - } - - void removeFlatten(std::shared_ptr<MatchSolution> solution){ - - assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n"); - assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n"); - - for (const auto& flatten : solution->at("Flatten")) { - removeFlatten(flatten); - } - } - - - void removeFlatten(std::shared_ptr<GraphView> graphView){ - - - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey("Flatten","getType($) =='Flatten'"); - regex->setNodeKey("FC","getType($) =='FC'"); - regex->addQuery("Flatten->FC"); - - for (const auto& solution : regex->match(graphView)) { - removeFlatten(solution); + const auto matches = SinglePassGraphMatching(graphView).match( + "(FC|MatMul)<-(Flatten)+" + ); + + for (const auto& solution : matches) { + auto flattenNodes(solution.graph->getNodes()); + flattenNodes.erase(solution.graph->rootNode()); + GraphView::replace(flattenNodes, {}); } - - } } diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp index e05e105d34a981e33cc1a0baaffa2702f1f6bbbb..68ac509e79e347106a9a132249f125ebe6e39f6a 100644 --- a/unit_tests/graphRegex/Test_GraphRegex.cpp +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -175,27 +175,17 @@ TEST_CASE("GraphRegexUser") { matmul1->addChild(add1, 0, 0); b1->addChild(add1, 0, 1); - auto fc = GenericOperator("FC", 1, 0, 1, "c"); - auto fl = GenericOperator("Flatten", 1, 0, 1, "c"); - - + auto fc = GenericOperator("FC", 1, 0, 1, "fc1"); + auto fl = GenericOperator("Flatten", 1, 0, 1, "flatten0"); + add1->addChild(fl, 0, 0); + fl->addChild(fc, 0, 0); auto g = std::make_shared<GraphView>(); - g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc}); - - std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>(); - - kitchenBook->setNodeKey("Add","getType($) =='Add'"); - kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'"); - kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); - kitchenBook->setNodeKey("FC","getType($) =='FC'"); - - //kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); - kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); - - kitchenBook->appliedRecipes(g); + g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1, fl, fc}); + matMulToFC(g); + removeFlatten(g); std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); - REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc})); + REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc})); //REQUIRE(newNodes.size() == 6); diff --git a/unit_tests/recipes/Test_removeFlatten.cpp b/unit_tests/recipes/Test_removeFlatten.cpp index 24f5aa2e231b5204add1c8f87cdeb7a71175ea05..c3b4c08d98115c9f081bbbf8cb677114b66c545a 100644 --- a/unit_tests/recipes/Test_removeFlatten.cpp +++ b/unit_tests/recipes/Test_removeFlatten.cpp @@ -42,7 +42,7 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { CHECK(g->getOrderedInputs().size() == 1); CHECK(g->getOrderedInputs()[0].first == fc0); - + CHECK(fc0->getParent(0) == nullptr); CHECK(fc0->getChildren(0).size() == 1); CHECK(g->rootNode() == fc0); @@ -54,10 +54,10 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { CHECK(g->getOrderedInputs().size() == 1); CHECK(g->getOrderedInputs()[0].first == fc0); - + CHECK(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedOutputs()[0].first == fc0); - + CHECK(fc0->getParent(0) == nullptr); CHECK(fc0->getChildren(0).size() == 0); CHECK(g->rootNode() == fc0); @@ -73,7 +73,7 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { CHECK(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedOutputs()[0].first == fc1); - + CHECK(fc1->getParent(0) == fc0); CHECK(fc0->getChildren(0)[0] == fc1); CHECK(g->rootNode() == fc0); @@ -87,10 +87,10 @@ TEST_CASE("[cpu/recipies] RemoveFlatten", "[RemoveFlatten][recipies]") { removeFlatten(g); CHECK(g->getOrderedInputs().size() == 0); - + CHECK(g->getOrderedOutputs().size() == 1); CHECK(g->getOrderedOutputs()[0].first == fc0); - + CHECK(fc0->getParent(0) == prod); CHECK(fc0->getChildren(0).size() == 0);