From 450b132960c4ff85520199f09746770984527fea Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 13 Nov 2024 16:34:27 +0100 Subject: [PATCH] Added simplify_graph() --- aidge_core/simplify_graph.py | 56 +++++++++++++++++++++++ include/aidge/graph/Matching.hpp | 4 +- include/aidge/recipes/Recipes.hpp | 14 ++++-- python_binding/graph/pybind_Matching.cpp | 30 ++++++------ python_binding/recipes/pybind_Recipes.cpp | 15 +++++- src/recipes/FuseToMetaOps.cpp | 9 +++- 6 files changed, 104 insertions(+), 24 deletions(-) create mode 100644 aidge_core/simplify_graph.py diff --git a/aidge_core/simplify_graph.py b/aidge_core/simplify_graph.py new file mode 100644 index 000000000..30ee04e6c --- /dev/null +++ b/aidge_core/simplify_graph.py @@ -0,0 +1,56 @@ +import numpy as np +import aidge_core + +def simplify_graph(graph: aidge_core.GraphView): + """ + Simplify a graph loaded from ONNX. + + :param graph: The GraphView to simplify. + :type graph: aidge_core.GraphView + """ + + def check_constant_producer(value): + def _check_constant_producer(node): + out = node.get_operator().get_output(0) + return (len(out) == 1 and np.isclose(out[0], value)) + return _check_constant_producer + + gm = aidge_core.SinglePassGraphMatching(graph) + gm.add_node_lambda("Constant_sqrt2", check_constant_producer(np.sqrt(2))) + gm.add_node_lambda("Constant_1", check_constant_producer(1)) + gm.add_node_lambda("Constant_0_5", check_constant_producer(0.5)) + + # Linear [from PyTorch ONNX] + aidge_core.fuse_to_metaops(gm, "MatMul-*>Add", "Linear") + + # LayerNorm [from PyTorch ONNX] + aidge_core.fuse_to_metaops(gm, "ReduceMean-*>Sub#1~>(Pow#1->ReduceMean-*>Add#1->Sqrt)-*>Div#1-*>Mul#1-*>Add#2;" + "Sub#1~*>Div#1;" + "Pow#1<1~Producer;" + "Add#1<*~Producer;" + "Mul#1<*~Producer;" + "Add#2<*~Producer;" + "Sub#1~>$", "LayerNorm") + + # ScaledDotProductAttention [from PyTorch ONNX] + aidge_core.fuse_to_metaops(gm, "MatMul->Div#1->Softmax-*>MatMul;" + "Div#1<1~Producer", "ScaledDotProductAttention") + + # MultiHeadAttention [from PyTorch ONNX] + aidge_core.fuse_to_metaops(gm, "ScaledDotProductAttention#1->Transpose->Reshape#1->Linear;" + "Reshape#1<1~Producer;" + "ScaledDotProductAttention#1<0-(Transpose<-Reshape#2<-Add#1);" + "ScaledDotProductAttention#1<1-(Transpose<-Reshape#3<-Add#2);" + "ScaledDotProductAttention#1<2-(Transpose<-Reshape#4<-Add#3);" + "Reshape#2<1~Producer;" + "Add#1<*-0-Split#1;" + "Add#2<*-1-Split#1;" + "Add#3<*-2-Split#1;" + "Split#1<-MatMul;" + "Split#1<1~Producer", "MultiHeadAttention") + + # GeLU [from PyTorch ONNX] + aidge_core.fuse_to_metaops(gm, "Div#1->Erf->Add#1-*>Mul->Mul#2;" + "Div#1<1~Producer[Constant_sqrt2];" + "Add#1<*~Producer[Constant_1];" + "Mul#2<*~Producer[Constant_0_5]", "GeLU") diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index 951aa6b29..b846af10b 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -154,13 +154,13 @@ public: */ std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches); - inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { + inline void addNodeLambda(const std::string& name, std::function<bool(const NodePtr&)> func) { mLambda[name] = func; } private: std::shared_ptr<GraphView> mGraph; - std::map<std::string, bool(*)(const NodePtr&)> mLambda; + std::map<std::string, std::function<bool(const NodePtr&)>> mLambda; /** * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}') diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 82ecc7d28..86c722b15 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -17,7 +17,7 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/graphRegex/matchFsm/MatchResult.hpp" +#include "aidge/graph/Matching.hpp" namespace Aidge { @@ -81,9 +81,6 @@ size_t removeIdentity(std::shared_ptr<GraphView> graph); */ void removeFlatten(std::shared_ptr<Node> flatten); - -void removeFlatten(std::shared_ptr<MatchSolution> solution); - /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * @@ -151,6 +148,15 @@ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); */ void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims); +/** + * Fuse each sub-graph matching a query in a Meta Operator. + * @param gm SinglePassGraphMatching containing the graph to manipulate + * @param query Sub-graph matching query + * @param type Type name of the resulting meta operators + * @return size_t Number of replacement +*/ +size_t fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type = ""); + /** * Fuse each sub-graph matching a query in a Meta Operator. * @param graph Graph to manipulate diff --git a/python_binding/graph/pybind_Matching.cpp b/python_binding/graph/pybind_Matching.cpp index 94f2471c3..af3857981 100644 --- a/python_binding/graph/pybind_Matching.cpp +++ b/python_binding/graph/pybind_Matching.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <pybind11/functional.h> #include <pybind11/stl.h> #include <memory> #include <string> @@ -31,21 +32,20 @@ void init_SinglePassGraphMatching(py::module& m) { py::class_<Aidge::SinglePassGraphMatching>(m, "SinglePassGraphMatching") .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) .def("match", - [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){ - // Note: Need to convert set to vector has MatchingResult is not hashable and - // set<MatchingResult> cannot be binded - std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint); - std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end()); - return vec_res; - }, - py::arg("query"), py::arg("disjoint") = false, - R"mydelimiter( Matches a query by direct, single-pass parse and match. - :param query: The query string to search. - :param disjoint: If true, only keep the longest disjoint matches. - :return: A set of MatchingResult instances. - )mydelimiter"); - - + [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){ + // Note: Need to convert set to vector has MatchingResult is not hashable and + // set<MatchingResult> cannot be binded + std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint); + std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end()); + return vec_res; + }, + py::arg("query"), py::arg("disjoint") = false, + R"mydelimiter( Matches a query by direct, single-pass parse and match. + :param query: The query string to search. + :param disjoint: If true, only keep the longest disjoint matches. + :return: A set of MatchingResult instances. + )mydelimiter") + .def("add_node_lambda", &SinglePassGraphMatching::addNodeLambda, py::arg("name"), py::arg("func")); } } // namespace Aidge diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 6908cbd91..77f20b9d6 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -112,7 +112,20 @@ void init_Recipes(py::module &m) :type recursive: bool )mydelimiter"); - m.def("fuse_to_metaops", fuseToMetaOps, py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter( + m.def("fuse_to_metaops", py::overload_cast<SinglePassGraphMatching&, const std::string&, const std::string&>(fuseToMetaOps), py::arg("gm"), py::arg("query"), py::arg("type") = "", R"mydelimiter( + Fuse each sub-graph matching a query in a Meta Operator. + + :param gm: SinglePassGraphMatching containing the graph to manipulate + :type gm: :py:class:`aidge_core.SinglePassGraphMatching` + :param query: Sub-graph matching query + :type query: str + :param type: Type name of the resulting meta operators + :type type: str, optional + :return: Number of sub-graph actually fused in a Meta Operator. + :rtype: int + )mydelimiter"); + + m.def("fuse_to_metaops", py::overload_cast<std::shared_ptr<GraphView>, const std::string&, const std::string&>(fuseToMetaOps), py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter( Fuse each sub-graph matching a query in a Meta Operator. :param graph_view: Graph view on which we want to apply the recipe diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp index 0ad5e5a1d..ac6536d7e 100644 --- a/src/recipes/FuseToMetaOps.cpp +++ b/src/recipes/FuseToMetaOps.cpp @@ -17,9 +17,9 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/recipes/Recipes.hpp" -size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { +size_t Aidge::fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type) { const auto metaType = (!type.empty()) ? type : query; - const auto matches = SinglePassGraphMatching(graphView).match(query); + const auto matches = gm.match(query); size_t nbReplaced = 0; for (const auto& match : matches) { @@ -48,3 +48,8 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size()); return nbReplaced; } + +size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { + SinglePassGraphMatching gm(graphView); + return fuseToMetaOps(gm, query, type); +} -- GitLab