From 19a41f8e0bcb53c013c081e027f4ab511cf6c7a9 Mon Sep 17 00:00:00 2001 From: vl241552 <vincent.lorrain@cea.fr> Date: Thu, 9 Nov 2023 14:01:42 +0000 Subject: [PATCH] removeFlatten recipies and fix pybind --- include/aidge/utils/Recipies.hpp | 6 +- python_binding/recipies/pybind_Recipies.cpp | 44 +++++++------- src/recipies/RemoveFlatten.cpp | 64 +++++++++++++++------ 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index 4b21a4f59..3236e7bf3 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -50,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView); * * @param nodes Strict set of Node to merge. */ -void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +void removeFlatten(std::shared_ptr<Node> flatten); + + +void removeFlatten(std::shared_ptr<MatchSolution> solution); + /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 93c131ef7..87abf3207 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -28,12 +28,13 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. - :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The MatMul and Add nodes to fuse. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. @@ -41,18 +42,20 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( - Recipie to remove a flatten operator. - :param nodes: The flatten operator to remove. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); - m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( + // Recipie to remove a flatten operator. - :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The flatten operator to remove. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); + + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + // :param nodes: The MatMul and Add nodes to fuse. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. @@ -60,11 +63,12 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( - Recipie to remove a flatten operator. + + // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( + // Recipie to remove a flatten operator. - :param nodes: The flatten operator to remove. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The flatten operator to remove. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); } } // namespace Aidge diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index fdfdbfd4a..452c32b92 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -18,33 +18,59 @@ // Graph Regex #include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/NodeRegex.hpp" +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" namespace Aidge { - void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - std::shared_ptr<Node> flatten; - for (const auto& element : nodes) { - assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); - if (element->type() == "Flatten"){ - flatten = element; - } - } + void removeFlatten(std::shared_ptr<Node> flatten) { + // assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); + // std::shared_ptr<Node> flatten; + // for (const auto& element : nodes) { + // assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); + // if (element->type() == "Flatten"){ + // flatten = element; + // } + // } GraphView::replace({flatten}, {}); } + void removeFlatten(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n"); + assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n"); + + for (const auto& flatten : solution->at("Flatten")) { + removeFlatten(flatten); + } + } + + + void removeFlatten(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["Flatten"] = new NodeRegex("Flatten"); - nodesRegex["FC"] = new NodeRegex("FC"); - std::vector<std::string> seqRegex; - seqRegex.push_back("Flatten->FC;"); - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - removeFlatten(matchNodes[i]); + // std::map<std::string,NodeRegex*> nodesRegex ; + // nodesRegex["Flatten"] = new NodeRegex("Flatten"); + // nodesRegex["FC"] = new NodeRegex("FC"); + // std::vector<std::string> seqRegex; + // seqRegex.push_back("Flatten->FC;"); + // GRegex GReg(nodesRegex, seqRegex); + // Match matches = GReg.match(graphView); + // std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + // for (size_t i = 0; i < matches.getNbMatch(); ++i) { + // removeFlatten(matchNodes[i]); + // } + + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Flatten","getType($) =='Flatten'"); + regex->setNodeKey("FC","getType($) =='FC'"); + regex->addQuery("Flatten->FC"); + + for (const auto& solution : regex->match(graphView)) { + removeFlatten(solution); } + + } } -- GitLab