From 450b132960c4ff85520199f09746770984527fea Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 13 Nov 2024 16:34:27 +0100
Subject: [PATCH] Added simplify_graph()

---
 aidge_core/simplify_graph.py              | 56 +++++++++++++++++++++++
 include/aidge/graph/Matching.hpp          |  4 +-
 include/aidge/recipes/Recipes.hpp         | 14 ++++--
 python_binding/graph/pybind_Matching.cpp  | 30 ++++++------
 python_binding/recipes/pybind_Recipes.cpp | 15 +++++-
 src/recipes/FuseToMetaOps.cpp             |  9 +++-
 6 files changed, 104 insertions(+), 24 deletions(-)
 create mode 100644 aidge_core/simplify_graph.py

diff --git a/aidge_core/simplify_graph.py b/aidge_core/simplify_graph.py
new file mode 100644
index 000000000..30ee04e6c
--- /dev/null
+++ b/aidge_core/simplify_graph.py
@@ -0,0 +1,56 @@
+import numpy as np
+import aidge_core
+
+def simplify_graph(graph: aidge_core.GraphView):
+    """
+    Simplify a graph loaded from ONNX.
+
+    :param graph: The GraphView to simplify.
+    :type graph: aidge_core.GraphView
+    """
+
+    def check_constant_producer(value):
+        def _check_constant_producer(node):
+            out = node.get_operator().get_output(0)
+            return (len(out) == 1 and np.isclose(out[0], value))
+        return _check_constant_producer
+
+    gm = aidge_core.SinglePassGraphMatching(graph)
+    gm.add_node_lambda("Constant_sqrt2", check_constant_producer(np.sqrt(2)))
+    gm.add_node_lambda("Constant_1", check_constant_producer(1))
+    gm.add_node_lambda("Constant_0_5", check_constant_producer(0.5))
+
+    # Linear [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "MatMul-*>Add", "Linear")
+
+    # LayerNorm [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "ReduceMean-*>Sub#1~>(Pow#1->ReduceMean-*>Add#1->Sqrt)-*>Div#1-*>Mul#1-*>Add#2;"
+                                   "Sub#1~*>Div#1;"
+                                   "Pow#1<1~Producer;"
+                                   "Add#1<*~Producer;"
+                                   "Mul#1<*~Producer;"
+                                   "Add#2<*~Producer;"
+                                   "Sub#1~>$", "LayerNorm")
+
+    # ScaledDotProductAttention [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "MatMul->Div#1->Softmax-*>MatMul;"
+                                   "Div#1<1~Producer", "ScaledDotProductAttention")
+
+    # MultiHeadAttention [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "ScaledDotProductAttention#1->Transpose->Reshape#1->Linear;"
+                                   "Reshape#1<1~Producer;"
+                                   "ScaledDotProductAttention#1<0-(Transpose<-Reshape#2<-Add#1);"
+                                   "ScaledDotProductAttention#1<1-(Transpose<-Reshape#3<-Add#2);"
+                                   "ScaledDotProductAttention#1<2-(Transpose<-Reshape#4<-Add#3);"
+                                   "Reshape#2<1~Producer;"
+                                   "Add#1<*-0-Split#1;"
+                                   "Add#2<*-1-Split#1;"
+                                   "Add#3<*-2-Split#1;"
+                                   "Split#1<-MatMul;"
+                                   "Split#1<1~Producer", "MultiHeadAttention")
+
+    # GeLU [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "Div#1->Erf->Add#1-*>Mul->Mul#2;"
+                                   "Div#1<1~Producer[Constant_sqrt2];"
+                                   "Add#1<*~Producer[Constant_1];"
+                                   "Mul#2<*~Producer[Constant_0_5]", "GeLU")
diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp
index 951aa6b29..b846af10b 100644
--- a/include/aidge/graph/Matching.hpp
+++ b/include/aidge/graph/Matching.hpp
@@ -154,13 +154,13 @@ public:
     */
     std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);
 
-    inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) {
+    inline void addNodeLambda(const std::string& name, std::function<bool(const NodePtr&)> func) {
         mLambda[name] = func;
     }
 
 private:
     std::shared_ptr<GraphView> mGraph;
-    std::map<std::string, bool(*)(const NodePtr&)> mLambda;
+    std::map<std::string, std::function<bool(const NodePtr&)>> mLambda;
 
     /**
      * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index 82ecc7d28..86c722b15 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -17,7 +17,7 @@
 
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
-#include "aidge/graphRegex/matchFsm/MatchResult.hpp"
+#include "aidge/graph/Matching.hpp"
 
 
 namespace Aidge {
@@ -81,9 +81,6 @@ size_t removeIdentity(std::shared_ptr<GraphView> graph);
  */
 void removeFlatten(std::shared_ptr<Node> flatten);
 
-
-void removeFlatten(std::shared_ptr<MatchSolution> solution);
-
 /**
  * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
  *
@@ -151,6 +148,15 @@ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false);
  */
 void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims);
 
+/**
+ * Fuse each sub-graph matching a query in a Meta Operator.
+ * @param gm SinglePassGraphMatching containing the graph to manipulate
+ * @param query Sub-graph matching query
+ * @param type Type name of the resulting meta operators
+ * @return size_t Number of replacement
+*/
+size_t fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type = "");
+
 /**
  * Fuse each sub-graph matching a query in a Meta Operator.
  * @param graph Graph to manipulate
diff --git a/python_binding/graph/pybind_Matching.cpp b/python_binding/graph/pybind_Matching.cpp
index 94f2471c3..af3857981 100644
--- a/python_binding/graph/pybind_Matching.cpp
+++ b/python_binding/graph/pybind_Matching.cpp
@@ -10,6 +10,7 @@
  ********************************************************************************/
 
 #include <pybind11/pybind11.h>
+#include <pybind11/functional.h>
 #include <pybind11/stl.h>
 #include <memory>
 #include <string>
@@ -31,21 +32,20 @@ void init_SinglePassGraphMatching(py::module& m) {
     py::class_<Aidge::SinglePassGraphMatching>(m, "SinglePassGraphMatching") 
         .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
         .def("match", 
-        [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){
-            // Note: Need to convert set to vector has MatchingResult is not hashable and 
-            // set<MatchingResult> cannot be binded
-            std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint);
-            std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end());
-            return vec_res;
-        },
-        py::arg("query"), py::arg("disjoint") = false, 
-        R"mydelimiter( Matches a query by direct, single-pass parse and match.
-        :param query: The query string to search.
-        :param disjoint: If true, only keep the longest disjoint matches.
-        :return: A set of MatchingResult instances.
-        )mydelimiter");
-
-
+            [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){
+                // Note: Need to convert set to vector has MatchingResult is not hashable and 
+                // set<MatchingResult> cannot be binded
+                std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint);
+                std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end());
+                return vec_res;
+            },
+            py::arg("query"), py::arg("disjoint") = false, 
+            R"mydelimiter( Matches a query by direct, single-pass parse and match.
+            :param query: The query string to search.
+            :param disjoint: If true, only keep the longest disjoint matches.
+            :return: A set of MatchingResult instances.
+            )mydelimiter")
+        .def("add_node_lambda", &SinglePassGraphMatching::addNodeLambda, py::arg("name"), py::arg("func"));
 
 }
 }  // namespace Aidge
diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp
index 6908cbd91..77f20b9d6 100644
--- a/python_binding/recipes/pybind_Recipes.cpp
+++ b/python_binding/recipes/pybind_Recipes.cpp
@@ -112,7 +112,20 @@ void init_Recipes(py::module &m)
     :type recursive: bool
     )mydelimiter");
 
-  m.def("fuse_to_metaops", fuseToMetaOps, py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
+  m.def("fuse_to_metaops", py::overload_cast<SinglePassGraphMatching&, const std::string&, const std::string&>(fuseToMetaOps), py::arg("gm"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
+    Fuse each sub-graph matching a query in a Meta Operator.
+
+    :param gm: SinglePassGraphMatching containing the graph to manipulate
+    :type gm: :py:class:`aidge_core.SinglePassGraphMatching`
+    :param query: Sub-graph matching query
+    :type query: str
+    :param type: Type name of the resulting meta operators
+    :type type: str, optional
+    :return: Number of sub-graph actually fused in a Meta Operator.
+    :rtype: int
+    )mydelimiter");
+
+  m.def("fuse_to_metaops", py::overload_cast<std::shared_ptr<GraphView>, const std::string&, const std::string&>(fuseToMetaOps), py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
     Fuse each sub-graph matching a query in a Meta Operator.
 
     :param graph_view: Graph view on which we want to apply the recipe
diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp
index 0ad5e5a1d..ac6536d7e 100644
--- a/src/recipes/FuseToMetaOps.cpp
+++ b/src/recipes/FuseToMetaOps.cpp
@@ -17,9 +17,9 @@
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/recipes/Recipes.hpp"
 
-size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) {
+size_t Aidge::fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type) {
     const auto metaType = (!type.empty()) ? type : query;
-    const auto matches = SinglePassGraphMatching(graphView).match(query);
+    const auto matches = gm.match(query);
 
     size_t nbReplaced = 0;
     for (const auto& match : matches) {
@@ -48,3 +48,8 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str
     Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size());
     return nbReplaced;
 }
+
+size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) {
+    SinglePassGraphMatching gm(graphView);
+    return fuseToMetaOps(gm, query, type);
+}
-- 
GitLab