From a551865ef34aeeabad94f274a1fb4ffe657ebe30 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 30 Aug 2024 08:48:47 +0000 Subject: [PATCH] Remove Flatten now use new graphmatching + remove Flatten beofre MatMul + remove multiple Flatten. --- python_binding/recipes/pybind_Recipes.cpp | 7 ++-- src/recipes/RemoveFlatten.cpp | 41 ++++++++++------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 1c04a320d..b68dfd035 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -21,7 +21,7 @@ namespace py = pybind11; namespace Aidge { -void init_Recipes(py::module &m) +void init_Recipes(py::module &m) { @@ -71,9 +71,10 @@ void init_Recipes(py::module &m) )mydelimiter"); m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( - Recipe to remove a flatten operator. + Recipe to remove a Flatten operator if it is followed by a FC or a MatMul. + The recipe can remove multiple Flatten operator if they are one after the other. - :param graph_view: Graph view on which we want to apply the recipe + :param graph_view: Graph view on which we want to apply the recipe. :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); diff --git a/src/recipes/RemoveFlatten.cpp b/src/recipes/RemoveFlatten.cpp index 8c1bf1bcf..8e59ea090 100644 --- a/src/recipes/RemoveFlatten.cpp +++ b/src/recipes/RemoveFlatten.cpp @@ -17,38 +17,33 @@ //Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" +// #include "aidge/graphRegex/GraphRegex.hpp" +#include "aidge/graph/Matching.hpp" namespace Aidge { - void removeFlatten(std::shared_ptr<Node> flatten) { - GraphView::replace({flatten}, {}); - } - - void removeFlatten(std::shared_ptr<MatchSolution> solution){ - assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n"); - assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n"); - - for (const auto& flatten : solution->at("Flatten")) { - removeFlatten(flatten); + void removeFlatten(const std::set<NodePtr>& solution){ + std::set<NodePtr> flattenNodes {}; + for (const auto& node : solution) { + if (node->type() == "Flatten"){ + printf("Flatten found.\n"); + flattenNodes.insert(node); + } + else if (! (node->type() == "MatMul" || node->type() == "FC")){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Node of type {} is not MatMul nor FC, an error during GraphMatching occured !", node->type()); + } } + GraphView::replace(flattenNodes, {}); } - - void removeFlatten(std::shared_ptr<GraphView> graphView){ - - - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey("Flatten","getType($) =='Flatten'"); - regex->setNodeKey("FC","getType($) =='FC'"); - regex->addQuery("Flatten->FC"); + const auto matches = SinglePassGraphMatching(graphView).match( + "(FC|MatMul)<-(Flatten)+" + ); - for (const auto& solution : regex->match(graphView)) { - removeFlatten(solution); + for (const auto& solution : matches) { + removeFlatten(solution.graph->getNodes()); } - - } } -- GitLab