diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 1c04a320d85a833cc3c0b666390edc7a8648214b..b68dfd035921a1dce4d12b9071a8df194e2ffdd5 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -21,7 +21,7 @@ namespace py = pybind11; namespace Aidge { -void init_Recipes(py::module &m) +void init_Recipes(py::module &m) { @@ -71,9 +71,10 @@ void init_Recipes(py::module &m) )mydelimiter"); m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( - Recipe to remove a flatten operator. + Recipe to remove a Flatten operator if it is followed by a FC or a MatMul. + The recipe can remove multiple Flatten operator if they are one after the other. - :param graph_view: Graph view on which we want to apply the recipe + :param graph_view: Graph view on which we want to apply the recipe. :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); diff --git a/src/recipes/RemoveFlatten.cpp b/src/recipes/RemoveFlatten.cpp index 8c1bf1bcf0bf79fda275867ff6430d5a937da172..8e59ea090c55ddb422a30acd1c7b9348ee1025c8 100644 --- a/src/recipes/RemoveFlatten.cpp +++ b/src/recipes/RemoveFlatten.cpp @@ -17,38 +17,33 @@ //Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" +// #include "aidge/graphRegex/GraphRegex.hpp" +#include "aidge/graph/Matching.hpp" namespace Aidge { - void removeFlatten(std::shared_ptr<Node> flatten) { - GraphView::replace({flatten}, {}); - } - - void removeFlatten(std::shared_ptr<MatchSolution> solution){ - assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n"); - assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n"); - - for (const auto& flatten : solution->at("Flatten")) { - removeFlatten(flatten); + void removeFlatten(const std::set<NodePtr>& solution){ + std::set<NodePtr> flattenNodes {}; + for (const auto& node : solution) { + if (node->type() == "Flatten"){ + printf("Flatten found.\n"); + flattenNodes.insert(node); + } + else if (! (node->type() == "MatMul" || node->type() == "FC")){ + AIDGE_THROW_OR_ABORT(std::runtime_error, "Node of type {} is not MatMul nor FC, an error during GraphMatching occured !", node->type()); + } } + GraphView::replace(flattenNodes, {}); } - - void removeFlatten(std::shared_ptr<GraphView> graphView){ - - - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey("Flatten","getType($) =='Flatten'"); - regex->setNodeKey("FC","getType($) =='FC'"); - regex->addQuery("Flatten->FC"); + const auto matches = SinglePassGraphMatching(graphView).match( + "(FC|MatMul)<-(Flatten)+" + ); - for (const auto& solution : regex->match(graphView)) { - removeFlatten(solution); + for (const auto& solution : matches) { + removeFlatten(solution.graph->getNodes()); } - - } }