diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 96ed5c42ce1d0dc557f8b9c0f12178e4b8a874dd..7bdb1f48b7498becb318d6b3eccd850f7f375623 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -21,7 +21,7 @@ class test_parameters(unittest.TestCase): def tearDown(self): pass - def test_conv(self): + def test_remove_flatten(self): graph_view = aidge_core.sequential([ aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), aidge_core.FC(50, name='0') @@ -33,6 +33,45 @@ class test_parameters(unittest.TestCase): self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) + def test_fuse_matmul_add(self): + matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0") + add0 = aidge_core.Add(name="Add0") + matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1") + add1 = aidge_core.Add(name="Add1") + + graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1]) + + w0 = aidge_core.Producer([1, 1], name="W0") + w0.add_child(matmul0, 0, 1) + graph_view.add(w0) + + b0 = aidge_core.Producer([1], name="B0") + b0.add_child(add0, 0, 1) + graph_view.add(b0) + + w1 = aidge_core.Producer([1, 1], name="W1") + w1.add_child(matmul1, 0, 1) + graph_view.add(w1) + + b1 = aidge_core.Producer([1], name="B1") + b1.add_child(add1, 0, 1) + graph_view.add(b1) + + graph_view.save("matmul") + old_nodes = graph_view.get_nodes() + aidge_core.fuse_mul_add(graph_view) + + self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2) + self.assertTrue("MatMul0" not in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("Add0" not in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("MatMul1" not in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("Add1" not in [i.name for i in graph_view.get_nodes()]) + + self.assertTrue("W0" in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("B0" in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("W1" in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("B1" in [i.name for i in graph_view.get_nodes()]) + # TODO : Vérifier que FC bien crée if __name__ == '__main__': unittest.main() diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index 68bcf17ace039349ddc95f40a324de954763d663..894e56fae2e9c2f6bcf11e4e76a433f5c8058080 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -16,11 +16,54 @@ #include "aidge/graph/GraphView.hpp" namespace Aidge{ + +// FUSE MATMUL + ADD -> FC + +/** + * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. + * + * @param nodes Strict set of Node to merge. + */ void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); +/** + * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. + * + * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + */ void fuseMulAdd(std::shared_ptr<GraphView> graphView); + +// REMOVE FLATTEN + FC -> FC + +/** + * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. + * + * @param nodes Strict set of Node to merge. + */ void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +/** + * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. + * + * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + */ void removeFlatten(std::shared_ptr<GraphView> graphView); + +// FUSE BN + FC || CONV -> FC || CONV + +/** + * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. + * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ + * + * @param nodes Strict set of Node to merge. + */ +void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes); +/** + * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. + * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ + * + * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + */ +void fuseBatchNorm(std::shared_ptr<GraphView> graphView); } diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 20ab34c7e921b44a5c17b65ed7a101d9a9c34a59..068e4122673e841038fee1c2a3f9b675d0e9fad5 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -20,14 +20,11 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" - +// Graph Regex +#include "aidge/graphmatching/GRegex.hpp" +#include "aidge/graphmatching/NodeRegex.hpp" using namespace Aidge; -/** - * @brief Merge MatMul and Add Node into FC. - * - * @param nodes Strict set of Node to merge. - */ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) @@ -61,10 +58,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // link weights & bias if (matmul->getParents(1)==nullptr) { matmul->getParents(0)->addChild(fc, 0, 1); + printf("Matmul out[1] == nullptr !\n"); } else { + printf("Matmul out[1] != nullptr !\n"); if (matmul->getParents(0)!=nullptr) matmul->getParents(0)->addChild(fc, 0, 0); - matmul->getParents(1)->addChild(fc, 0, 1); + matmul->input(1).first->addChild(fc, 0, 1); } (producer_add_bias.first)->addChild(fc,0,2); @@ -79,3 +78,17 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ } +void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ + + std::map<std::string,NodeRegex*> nodesRegex ; + nodesRegex["MatMul"] = new NodeRegex("MatMul"); + nodesRegex["Add"] = new NodeRegex("Add"); + std::vector<std::string> seqRegex; + seqRegex.push_back("MatMul -> Add;"); + GRegex GReg(nodesRegex, seqRegex); + Match matches = GReg.match(graphView); + std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); + for (size_t i = 0; i < matches.getNbMatch(); ++i) { + fuseMulAdd(matchNodes[i]); + } +}