Skip to content
Snippets Groups Projects
Commit 8024b47f authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add unittest for fuse Mul + Add -> FC and a method with graphView as input.

parent 0fbf84e7
No related branches found
No related tags found
No related merge requests found
......@@ -21,7 +21,7 @@ class test_parameters(unittest.TestCase):
def tearDown(self):
pass
def test_conv(self):
def test_remove_flatten(self):
graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
aidge_core.FC(50, name='0')
......@@ -33,6 +33,45 @@ class test_parameters(unittest.TestCase):
self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
def test_fuse_matmul_add(self):
matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0")
add0 = aidge_core.Add(name="Add0")
matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1")
add1 = aidge_core.Add(name="Add1")
graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1])
w0 = aidge_core.Producer([1, 1], name="W0")
w0.add_child(matmul0, 0, 1)
graph_view.add(w0)
b0 = aidge_core.Producer([1], name="B0")
b0.add_child(add0, 0, 1)
graph_view.add(b0)
w1 = aidge_core.Producer([1, 1], name="W1")
w1.add_child(matmul1, 0, 1)
graph_view.add(w1)
b1 = aidge_core.Producer([1], name="B1")
b1.add_child(add1, 0, 1)
graph_view.add(b1)
graph_view.save("matmul")
old_nodes = graph_view.get_nodes()
aidge_core.fuse_mul_add(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2)
self.assertTrue("MatMul0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("Add0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("MatMul1" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("Add1" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("W0" in [i.name for i in graph_view.get_nodes()])
self.assertTrue("B0" in [i.name for i in graph_view.get_nodes()])
self.assertTrue("W1" in [i.name for i in graph_view.get_nodes()])
self.assertTrue("B1" in [i.name for i in graph_view.get_nodes()])
# TODO : Vérifier que FC bien crée
if __name__ == '__main__':
unittest.main()
......
......@@ -16,11 +16,54 @@
#include "aidge/graph/GraphView.hpp"
namespace Aidge{
// FUSE MATMUL + ADD -> FC
/**
* @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
*
* @param nodes Strict set of Node to merge.
*/
void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
/**
* @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
*
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/
void fuseMulAdd(std::shared_ptr<GraphView> graphView);
// REMOVE FLATTEN + FC -> FC
/**
* @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
*
* @param nodes Strict set of Node to merge.
*/
void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
/**
* @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
*
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/
void removeFlatten(std::shared_ptr<GraphView> graphView);
// FUSE BN + FC || CONV -> FC || CONV
/**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
*
* @param nodes Strict set of Node to merge.
*/
void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes);
/**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
*
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/
void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
}
......
......@@ -20,14 +20,11 @@
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
using namespace Aidge;
/**
* @brief Merge MatMul and Add Node into FC.
*
* @param nodes Strict set of Node to merge.
*/
void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add)
......@@ -61,10 +58,12 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// link weights & bias
if (matmul->getParents(1)==nullptr) {
matmul->getParents(0)->addChild(fc, 0, 1);
printf("Matmul out[1] == nullptr !\n");
} else {
printf("Matmul out[1] != nullptr !\n");
if (matmul->getParents(0)!=nullptr)
matmul->getParents(0)->addChild(fc, 0, 0);
matmul->getParents(1)->addChild(fc, 0, 1);
matmul->input(1).first->addChild(fc, 0, 1);
}
(producer_add_bias.first)->addChild(fc,0,2);
......@@ -79,3 +78,17 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
}
void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["MatMul"] = new NodeRegex("MatMul");
nodesRegex["Add"] = new NodeRegex("Add");
std::vector<std::string> seqRegex;
seqRegex.push_back("MatMul -> Add;");
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
fuseMulAdd(matchNodes[i]);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment