Skip to content
Snippets Groups Projects
Commit 6fd58bac authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] FuseMulAdd python test and fusemuladd function

parent e98dd055
No related branches found
No related tags found
1 merge request!9Fuse bn
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
......@@ -60,15 +60,15 @@ class test_recipies(unittest.TestCase):
aidge_core.fuse_mul_add(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2)
self.assertTrue("MatMul0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("Add0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("MatMul1" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("Add1" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue("W0" in [i.name for i in graph_view.get_nodes()])
self.assertTrue("B0" in [i.name for i in graph_view.get_nodes()])
self.assertTrue("W1" in [i.name for i in graph_view.get_nodes()])
self.assertTrue("B1" in [i.name for i in graph_view.get_nodes()])
self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()])
self.assertTrue("Add0" not in [i.name() for i in graph_view.get_nodes()])
self.assertTrue("MatMul1" not in [i.name() for i in graph_view.get_nodes()])
self.assertTrue("Add1" not in [i.name() for i in graph_view.get_nodes()])
self.assertTrue("W0" in [i.name() for i in graph_view.get_nodes()])
self.assertTrue("B0" in [i.name() for i in graph_view.get_nodes()])
self.assertTrue("W1" in [i.name() for i in graph_view.get_nodes()])
self.assertTrue("B1" in [i.name() for i in graph_view.get_nodes()])
# TODO : Vérifier que FC bien crée
if __name__ == '__main__':
......
......@@ -58,9 +58,9 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// link weights & bias
if (matmul->getParent(1)==nullptr) {
matmul->getParent(0)->addChild(fc, 0, 1);
printf("Matmul out[1] == nullptr !\n");
printf("MatMul out[1] == nullptr !\n");
} else {
printf("Matmul out[1] != nullptr !\n");
printf("MatMul out[1] != nullptr !\n");
if (matmul->getParent(0)!=nullptr)
matmul->getParent(0)->addChild(fc, 0, 0);
matmul->input(1).first->addChild(fc, 0, 1);
......@@ -73,7 +73,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
// Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ?
auto nodeToReplace = std::make_shared<GraphView>();
nodeToReplace->add(nodes);
nodeToReplace->add(nodes, false);
nodeToReplace->replaceWith({fc});
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment