From 8024b47f20c81b946988ec41d763b75170ab43a0 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 15 Sep 2023 07:39:12 +0000
Subject: [PATCH] Add unittest for fuse Mul + Add -> FC and a method with
 graphView as input.

---
 aidge_core/unit_tests/test_recipies.py | 41 +++++++++++++++++++++++-
 include/aidge/utils/Recipies.hpp       | 43 ++++++++++++++++++++++++++
 src/recipies/FuseMulAdd.cpp            | 27 +++++++++++-----
 3 files changed, 103 insertions(+), 8 deletions(-)

diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py
index 96ed5c42c..7bdb1f48b 100644
--- a/aidge_core/unit_tests/test_recipies.py
+++ b/aidge_core/unit_tests/test_recipies.py
@@ -21,7 +21,7 @@ class test_parameters(unittest.TestCase):
     def tearDown(self):
         pass
 
-    def test_conv(self):
+    def test_remove_flatten(self):
         graph_view = aidge_core.sequential([
             aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
             aidge_core.FC(50, name='0')
@@ -33,6 +33,45 @@ class test_parameters(unittest.TestCase):
 
         self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
 
+    def test_fuse_matmul_add(self):
+        matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0")
+        add0 = aidge_core.Add(name="Add0")
+        matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1")
+        add1 = aidge_core.Add(name="Add1")
+
+        graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1])
+
+        w0 = aidge_core.Producer([1, 1], name="W0")
+        w0.add_child(matmul0, 0, 1)
+        graph_view.add(w0)
+
+        b0 = aidge_core.Producer([1], name="B0")
+        b0.add_child(add0, 0, 1)
+        graph_view.add(b0)
+
+        w1 = aidge_core.Producer([1, 1], name="W1")
+        w1.add_child(matmul1, 0, 1)
+        graph_view.add(w1)
+
+        b1 = aidge_core.Producer([1], name="B1")
+        b1.add_child(add1, 0, 1)
+        graph_view.add(b1)
+
+        graph_view.save("matmul")
+        old_nodes = graph_view.get_nodes()
+        aidge_core.fuse_mul_add(graph_view)
+
+        self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2)
+        self.assertTrue("MatMul0" not in [i.name for i in graph_view.get_nodes()])
+        self.assertTrue("Add0" not in [i.name for i in graph_view.get_nodes()])
+        self.assertTrue("MatMul1" not in [i.name for i in graph_view.get_nodes()])
+        self.assertTrue("Add1" not in [i.name for i in graph_view.get_nodes()])
+
+        self.assertTrue("W0" in [i.name for i in graph_view.get_nodes()])
+        self.assertTrue("B0" in [i.name for i in graph_view.get_nodes()])
+        self.assertTrue("W1" in [i.name for i in graph_view.get_nodes()])
+        self.assertTrue("B1" in [i.name for i in graph_view.get_nodes()])
+        # TODO : Vérifier que FC bien crée
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp
index 68bcf17ac..894e56fae 100644
--- a/include/aidge/utils/Recipies.hpp
+++ b/include/aidge/utils/Recipies.hpp
@@ -16,11 +16,54 @@
 #include "aidge/graph/GraphView.hpp"
 
 namespace Aidge{
+
+// FUSE MATMUL + ADD -> FC
+
+/**
+ * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
+ *
+ * @param nodes Strict set of Node to merge.
+ */
 void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
+/**
+ * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
+ *
+ * @param graphView Graph view to use graph matching on, in order to apply transfomrations.
+ */
 void fuseMulAdd(std::shared_ptr<GraphView> graphView);
 
+
+// REMOVE FLATTEN + FC -> FC
+
+/**
+ * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
+ *
+ * @param nodes Strict set of Node to merge.
+ */
 void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
+/**
+ * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
+ *
+ * @param graphView Graph view to use graph matching on, in order to apply transfomrations.
+ */
 void removeFlatten(std::shared_ptr<GraphView> graphView);
+ 
+// FUSE BN + FC || CONV -> FC || CONV
+
+/**
+ * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
+ * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
+ *
+ * @param nodes Strict set of Node to merge.
+ */
+void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes);
+/**
+ * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
+ * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
+ *
+ * @param graphView Graph view to use graph matching on, in order to apply transfomrations.
+ */
+void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
 
 }
 
diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp
index 20ab34c7e..068e41226 100644
--- a/src/recipies/FuseMulAdd.cpp
+++ b/src/recipies/FuseMulAdd.cpp
@@ -20,14 +20,11 @@
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/GenericOperator.hpp"
-
+// Graph Regex
+#include "aidge/graphmatching/GRegex.hpp"
+#include "aidge/graphmatching/NodeRegex.hpp"
 using namespace Aidge;
 
-/**
- * @brief Merge MatMul and Add Node into FC.
- *
- * @param nodes Strict set of Node to merge.
- */
 void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
     // Fuse Mulmat & Add into FC
     // Inputs : old nodes (pointers on mul & add)
@@ -61,10 +58,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
     // link weights & bias
     if (matmul->getParents(1)==nullptr) {
         matmul->getParents(0)->addChild(fc, 0, 1);
+        printf("Matmul out[1] == nullptr !\n");
     } else {
+        printf("Matmul out[1] != nullptr !\n");
         if (matmul->getParents(0)!=nullptr)
             matmul->getParents(0)->addChild(fc, 0, 0);
-        matmul->getParents(1)->addChild(fc, 0, 1);
+        matmul->input(1).first->addChild(fc, 0, 1);
     }
     (producer_add_bias.first)->addChild(fc,0,2);
 
@@ -79,3 +78,17 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
 
 }
 
+void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){
+
+    std::map<std::string,NodeRegex*> nodesRegex ;
+    nodesRegex["MatMul"] = new NodeRegex("MatMul");
+    nodesRegex["Add"] = new NodeRegex("Add");
+    std::vector<std::string> seqRegex;
+    seqRegex.push_back("MatMul -> Add;");
+    GRegex GReg(nodesRegex, seqRegex);
+    Match matches = GReg.match(graphView);
+    std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
+    for (size_t i = 0; i < matches.getNbMatch(); ++i) {
+        fuseMulAdd(matchNodes[i]);
+    }
+}
-- 
GitLab