From 19a41f8e0bcb53c013c081e027f4ab511cf6c7a9 Mon Sep 17 00:00:00 2001
From: vl241552 <vincent.lorrain@cea.fr>
Date: Thu, 9 Nov 2023 14:01:42 +0000
Subject: [PATCH] removeFlatten recipies and fix pybind

---
 include/aidge/utils/Recipies.hpp            |  6 +-
 python_binding/recipies/pybind_Recipies.cpp | 44 +++++++-------
 src/recipies/RemoveFlatten.cpp              | 64 +++++++++++++++------
 3 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp
index 4b21a4f59..3236e7bf3 100644
--- a/include/aidge/utils/Recipies.hpp
+++ b/include/aidge/utils/Recipies.hpp
@@ -50,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView);
  *
  * @param nodes Strict set of Node to merge.
  */
-void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
+void removeFlatten(std::shared_ptr<Node> flatten);
+
+
+void removeFlatten(std::shared_ptr<MatchSolution> solution);
+
 /**
  * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
  *
diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp
index 93c131ef7..87abf3207 100644
--- a/python_binding/recipies/pybind_Recipies.cpp
+++ b/python_binding/recipies/pybind_Recipies.cpp
@@ -28,12 +28,13 @@ void init_Recipies(py::module &m) {
     :param graph_view: Graph view on which we want to apply the recipie
     :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");
-  m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
-    Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+    
+  // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
+  //   Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
 
-    :param nodes: The MatMul and Add nodes to fuse.
-    :type nodes: list of :py:class:`aidge_core.Node`
-    )mydelimiter");
+  //   :param nodes: The MatMul and Add nodes to fuse.
+  //   :type nodes: list of :py:class:`aidge_core.Node`
+  //   )mydelimiter");
 
   m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
     Recipie to remove a flatten operator.
@@ -41,18 +42,20 @@ void init_Recipies(py::module &m) {
     :param graph_view: Graph view on which we want to apply the recipie
     :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");
-  m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
-    Recipie to remove a flatten operator.
 
-    :param nodes: The flatten operator to remove.
-    :type nodes: list of :py:class:`aidge_core.Node`
-    )mydelimiter");
-  m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
-    Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+  // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
+  //   Recipie to remove a flatten operator.
 
-    :param nodes: The MatMul and Add nodes to fuse.
-    :type nodes: list of :py:class:`aidge_core.Node`
-    )mydelimiter");
+  //   :param nodes: The flatten operator to remove.
+  //   :type nodes: list of :py:class:`aidge_core.Node`
+  //   )mydelimiter");
+
+  // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
+  //   Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+
+  //   :param nodes: The MatMul and Add nodes to fuse.
+  //   :type nodes: list of :py:class:`aidge_core.Node`
+  //   )mydelimiter");
 
   m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
     Recipie to remove a flatten operator.
@@ -60,11 +63,12 @@ void init_Recipies(py::module &m) {
     :param graph_view: Graph view on which we want to apply the recipie
     :type graph_view: :py:class:`aidge_core.GraphView`
     )mydelimiter");
-  m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
-    Recipie to remove a flatten operator.
+    
+  // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
+  //   Recipie to remove a flatten operator.
 
-    :param nodes: The flatten operator to remove.
-    :type nodes: list of :py:class:`aidge_core.Node`
-    )mydelimiter");
+  //   :param nodes: The flatten operator to remove.
+  //   :type nodes: list of :py:class:`aidge_core.Node`
+  //   )mydelimiter");
 }
 } // namespace Aidge
diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp
index fdfdbfd4a..452c32b92 100644
--- a/src/recipies/RemoveFlatten.cpp
+++ b/src/recipies/RemoveFlatten.cpp
@@ -18,33 +18,59 @@
 // Graph Regex
 #include "aidge/graphmatching/GRegex.hpp"
 #include "aidge/graphmatching/NodeRegex.hpp"
+//Graph Regex
+#include "aidge/graphRegex/GraphRegex.hpp"
 
 
 namespace Aidge {
-    void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
-        assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
-        std::shared_ptr<Node> flatten;
-        for (const auto& element : nodes) {
-            assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
-            if (element->type() == "Flatten"){
-                flatten = element;
-            }
-        }
+    void removeFlatten(std::shared_ptr<Node> flatten) {
+        // assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
+        // std::shared_ptr<Node> flatten;
+        // for (const auto& element : nodes) {
+        //     assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
+        //     if (element->type() == "Flatten"){
+        //         flatten = element;
+        //     }
+        // }
 
         GraphView::replace({flatten}, {});
     }
 
+    void removeFlatten(std::shared_ptr<MatchSolution> solution){
+
+        assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n");
+        assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n");
+
+        for (const auto& flatten : solution->at("Flatten")) {
+            removeFlatten(flatten);
+        }
+    }
+
+
+
     void removeFlatten(std::shared_ptr<GraphView> graphView){
-        std::map<std::string,NodeRegex*> nodesRegex ;
-        nodesRegex["Flatten"] = new NodeRegex("Flatten");
-        nodesRegex["FC"] = new NodeRegex("FC");
-        std::vector<std::string> seqRegex;
-        seqRegex.push_back("Flatten->FC;");
-        GRegex GReg(nodesRegex, seqRegex);
-        Match matches = GReg.match(graphView);
-        std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
-        for (size_t i = 0; i < matches.getNbMatch(); ++i) {
-            removeFlatten(matchNodes[i]);
+        // std::map<std::string,NodeRegex*> nodesRegex ;
+        // nodesRegex["Flatten"] = new NodeRegex("Flatten");
+        // nodesRegex["FC"] = new NodeRegex("FC");
+        // std::vector<std::string> seqRegex;
+        // seqRegex.push_back("Flatten->FC;");
+        // GRegex GReg(nodesRegex, seqRegex);
+        // Match matches = GReg.match(graphView);
+        // std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
+        // for (size_t i = 0; i < matches.getNbMatch(); ++i) {
+        //     removeFlatten(matchNodes[i]);
+        // }
+
+
+        std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
+        regex->setNodeKey("Flatten","getType($) =='Flatten'");
+        regex->setNodeKey("FC","getType($) =='FC'");
+        regex->addQuery("Flatten->FC");
+
+        for (const auto& solution : regex->match(graphView)) {
+            removeFlatten(solution);
         }
+
+
     }
 }
-- 
GitLab