From b331fa1a779beb5130c2499b52dcdb5dec751876 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Wed, 13 Sep 2023 06:24:11 +0000
Subject: [PATCH] [Recipies] Extend recipies so that they can also use
 schedulers as input.

---
 include/aidge/utils/Recipies.hpp            |  8 +++---
 python_binding/recipies/pybind_Recipies.cpp | 32 ++++++++++++---------
 src/recipies/FuseMulAdd.cpp                 | 11 +++----
 src/recipies/RemoveFlatten.cpp              | 20 ++++++++++++-
 4 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp
index 4cbf8fd28..68bcf17ac 100644
--- a/include/aidge/utils/Recipies.hpp
+++ b/include/aidge/utils/Recipies.hpp
@@ -16,12 +16,12 @@
 #include "aidge/graph/GraphView.hpp"
 
 namespace Aidge{
-
 void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
-void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
+void fuseMulAdd(std::shared_ptr<GraphView> graphView);
 
+void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
+void removeFlatten(std::shared_ptr<GraphView> graphView);
 
 }
 
-
-#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
\ No newline at end of file
+#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp
index b4147dcb4..d1c392f74 100644
--- a/python_binding/recipies/pybind_Recipies.cpp
+++ b/python_binding/recipies/pybind_Recipies.cpp
@@ -20,24 +20,30 @@ namespace py = pybind11;
 
 namespace Aidge {
 void init_Recipies(py::module &m) {
-  m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter(
-    Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator.
-    
-    Parameters
-    ----------
+  m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
+    Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+
+    :param graph_view: Graph view on which we want to apply the recipie
+    :type graph_view: :py:class:`aidge_core.GraphView`
+    )mydelimiter");
+  m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
+    Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+
     :param nodes: The MatMul and Add nodes to fuse.
-    :type nodes: list of `aidge.node`
+    :type nodes: list of :py:class:`aidge_core.Node`
+    )mydelimiter");
+  m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
+    Recipie to remove a flatten operator.
 
+    :param graph_view: Graph view on which we want to apply the recipie
+    :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");
-  m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter(
+  m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
     Recipie to remove a flatten operator.
-    
-    Parameters
-    ----------
-    :param nodes: The flatten operator to remove.
-    :type nodes: list of `aidge.node`
 
+    :param nodes: The flatten operator to remove.
+    :type nodes: list of :py:class:`aidge_core.Node`
     )mydelimiter");
-  
+
 }
 } // namespace Aidge
diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp
index dc565bf0a..20ab34c7e 100644
--- a/src/recipies/FuseMulAdd.cpp
+++ b/src/recipies/FuseMulAdd.cpp
@@ -25,16 +25,16 @@ using namespace Aidge;
 
 /**
  * @brief Merge MatMul and Add Node into FC.
- * 
+ *
  * @param nodes Strict set of Node to merge.
  */
 void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
     // Fuse Mulmat & Add into FC
     // Inputs : old nodes (pointers on mul & add)
-    
+
     assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
     // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ?
-    
+
     // Step 0 : Assert the nodes types are correct to be fused
     std::shared_ptr<Node> add;
     std::shared_ptr<Node> matmul;
@@ -53,7 +53,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
     auto producer_add_bias = add->input(1);
     Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0);
 
-    // Instanciate FC  
+    // Instanciate FC
     //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
     std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false));
 
@@ -77,4 +77,5 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
     nodeToReplace->add(nodes);
     nodeToReplace->replaceWith({fc});
 
-}
\ No newline at end of file
+}
+
diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp
index cc3c3324e..5bba4a716 100644
--- a/src/recipies/RemoveFlatten.cpp
+++ b/src/recipies/RemoveFlatten.cpp
@@ -15,10 +15,28 @@
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/utils/Recipies.hpp"
 
+// Graph Regex
+#include "aidge/graphmatching/GRegex.hpp"
+#include "aidge/graphmatching/NodeRegex.hpp"
+
+
 namespace Aidge {
     void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
         auto g = std::make_shared<GraphView>();
         g->add(std::set<std::shared_ptr<Node>>({nodes}));
         g->replaceWith({});
     }
-}
\ No newline at end of file
+
+    void removeFlatten(std::shared_ptr<GraphView> graphView){
+        std::map<std::string,NodeRegex*> nodesRegex ;
+        nodesRegex["Flatten"] = new NodeRegex("Flatten");
+        std::vector<std::string> seqRegex;
+        seqRegex.push_back("Flatten;");
+        GRegex GReg(nodesRegex, seqRegex);
+        Match matches = GReg.match(graphView);
+        std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
+        for (size_t i = 0; i < matches.getNbMatch(); ++i) {
+            removeFlatten(matchNodes[i]);
+        }
+    }
+}
-- 
GitLab