From b331fa1a779beb5130c2499b52dcdb5dec751876 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Wed, 13 Sep 2023 06:24:11 +0000 Subject: [PATCH] [Recipies] Extend recipies so that they can also use schedulers as input. --- include/aidge/utils/Recipies.hpp | 8 +++--- python_binding/recipies/pybind_Recipies.cpp | 32 ++++++++++++--------- src/recipies/FuseMulAdd.cpp | 11 +++---- src/recipies/RemoveFlatten.cpp | 20 ++++++++++++- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index 4cbf8fd28..68bcf17ac 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -16,12 +16,12 @@ #include "aidge/graph/GraphView.hpp" namespace Aidge{ - void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); -void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +void fuseMulAdd(std::shared_ptr<GraphView> graphView); +void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +void removeFlatten(std::shared_ptr<GraphView> graphView); } - -#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index b4147dcb4..d1c392f74 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -20,24 +20,30 @@ namespace py = pybind11; namespace Aidge { void init_Recipies(py::module &m) { - m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator. - - Parameters - ---------- + m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter( + Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + :param graph_view: Graph view on which we want to apply the recipie + :type graph_view: :py:class:`aidge_core.GraphView` + )mydelimiter"); + m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of `aidge.node` + :type nodes: list of :py:class:`aidge_core.Node` + )mydelimiter"); + m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( + Recipie to remove a flatten operator. + :param graph_view: Graph view on which we want to apply the recipie + :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter( + m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( Recipie to remove a flatten operator. - - Parameters - ---------- - :param nodes: The flatten operator to remove. - :type nodes: list of `aidge.node` + :param nodes: The flatten operator to remove. + :type nodes: list of :py:class:`aidge_core.Node` )mydelimiter"); - + } } // namespace Aidge diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index dc565bf0a..20ab34c7e 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -25,16 +25,16 @@ using namespace Aidge; /** * @brief Merge MatMul and Add Node into FC. - * + * * @param nodes Strict set of Node to merge. */ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - + assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ? - + // Step 0 : Assert the nodes types are correct to be fused std::shared_ptr<Node> add; std::shared_ptr<Node> matmul; @@ -53,7 +53,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ auto producer_add_bias = add->input(1); Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0); - // Instanciate FC + // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false)); @@ -77,4 +77,5 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ nodeToReplace->add(nodes); nodeToReplace->replaceWith({fc}); -} \ No newline at end of file +} + diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index cc3c3324e..5bba4a716 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -15,10 +15,28 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/utils/Recipies.hpp" +// Graph Regex +#include "aidge/graphmatching/GRegex.hpp" +#include "aidge/graphmatching/NodeRegex.hpp" + + namespace Aidge { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { auto g = std::make_shared<GraphView>(); g->add(std::set<std::shared_ptr<Node>>({nodes})); g->replaceWith({}); } -} \ No newline at end of file + + void removeFlatten(std::shared_ptr<GraphView> graphView){ + std::map<std::string,NodeRegex*> nodesRegex ; + nodesRegex["Flatten"] = new NodeRegex("Flatten"); + std::vector<std::string> seqRegex; + seqRegex.push_back("Flatten;"); + GRegex GReg(nodesRegex, seqRegex); + Match matches = GReg.match(graphView); + std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + for (size_t i = 0; i < matches.getNbMatch(); ++i) { + removeFlatten(matchNodes[i]); + } + } +} -- GitLab