diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index 4cbf8fd284bef314dbe28b19ebdae05172467bad..68bcf17ace039349ddc95f40a324de954763d663 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -16,12 +16,12 @@ #include "aidge/graph/GraphView.hpp" namespace Aidge{ - void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); -void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +void fuseMulAdd(std::shared_ptr<GraphView> graphView); +void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +void removeFlatten(std::shared_ptr<GraphView> graphView); } - -#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index b4147dcb4fb82dbfe9f5b4605604725c6945ece9..d1c392f747cfb8bc7d848c0da05ef4635086280f 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -20,24 +20,30 @@ namespace py = pybind11; namespace Aidge { void init_Recipies(py::module &m) { - m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator. - - Parameters - ---------- + m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter( + Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + :param graph_view: Graph view on which we want to apply the recipie + :type graph_view: :py:class:`aidge_core.GraphView` + )mydelimiter"); + m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of `aidge.node` + :type nodes: list of :py:class:`aidge_core.Node` + )mydelimiter"); + m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( + Recipie to remove a flatten operator. + :param graph_view: Graph view on which we want to apply the recipie + :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter( + m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( Recipie to remove a flatten operator. - - Parameters - ---------- - :param nodes: The flatten operator to remove. - :type nodes: list of `aidge.node` + :param nodes: The flatten operator to remove. + :type nodes: list of :py:class:`aidge_core.Node` )mydelimiter"); - + } } // namespace Aidge diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index dc565bf0acc7747d79ec12df973a82d86fc79503..20ab34c7e921b44a5c17b65ed7a101d9a9c34a59 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -25,16 +25,16 @@ using namespace Aidge; /** * @brief Merge MatMul and Add Node into FC. - * + * * @param nodes Strict set of Node to merge. */ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - + assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ? - + // Step 0 : Assert the nodes types are correct to be fused std::shared_ptr<Node> add; std::shared_ptr<Node> matmul; @@ -53,7 +53,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ auto producer_add_bias = add->input(1); Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0); - // Instanciate FC + // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false)); @@ -77,4 +77,5 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ nodeToReplace->add(nodes); nodeToReplace->replaceWith({fc}); -} \ No newline at end of file +} + diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index cc3c3324e40636a1edcbc73cdc4a9dcfeec8a026..5bba4a71622a4d5df33970ee57b0459ed06f93d3 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -15,10 +15,28 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/utils/Recipies.hpp" +// Graph Regex +#include "aidge/graphmatching/GRegex.hpp" +#include "aidge/graphmatching/NodeRegex.hpp" + + namespace Aidge { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { auto g = std::make_shared<GraphView>(); g->add(std::set<std::shared_ptr<Node>>({nodes})); g->replaceWith({}); } -} \ No newline at end of file + + void removeFlatten(std::shared_ptr<GraphView> graphView){ + std::map<std::string,NodeRegex*> nodesRegex ; + nodesRegex["Flatten"] = new NodeRegex("Flatten"); + std::vector<std::string> seqRegex; + seqRegex.push_back("Flatten;"); + GRegex GReg(nodesRegex, seqRegex); + Match matches = GReg.match(graphView); + std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + for (size_t i = 0; i < matches.getNbMatch(); ++i) { + removeFlatten(matchNodes[i]); + } + } +}