From 6fd58bac66c86cc9e8a393ee5375133e8ef8320a Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 3 Oct 2023 12:33:03 +0000 Subject: [PATCH] [Fix] FuseMulAdd python test and fusemuladd function --- aidge_core/unit_tests/test_recipies.py | 18 +++++++++--------- src/recipies/FuseMulAdd.cpp | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 15ec22d3e..754907443 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -60,15 +60,15 @@ class test_recipies(unittest.TestCase): aidge_core.fuse_mul_add(graph_view) self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2) - self.assertTrue("MatMul0" not in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("Add0" not in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("MatMul1" not in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("Add1" not in [i.name for i in graph_view.get_nodes()]) - - self.assertTrue("W0" in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("B0" in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("W1" in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("B1" in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("Add0" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("MatMul1" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("Add1" not in [i.name() for i in graph_view.get_nodes()]) + + self.assertTrue("W0" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("B0" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("W1" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("B1" in [i.name() for i in graph_view.get_nodes()]) # TODO : Vérifier que FC bien crée if __name__ == '__main__': diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index c31578d91..1de79890f 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -58,9 +58,9 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // link weights & bias if (matmul->getParent(1)==nullptr) { matmul->getParent(0)->addChild(fc, 0, 1); - printf("Matmul out[1] == nullptr !\n"); + printf("MatMul out[1] == nullptr !\n"); } else { - printf("Matmul out[1] != nullptr !\n"); + printf("MatMul out[1] != nullptr !\n"); if (matmul->getParent(0)!=nullptr) matmul->getParent(0)->addChild(fc, 0, 0); matmul->input(1).first->addChild(fc, 0, 1); @@ -73,7 +73,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ? auto nodeToReplace = std::make_shared<GraphView>(); - nodeToReplace->add(nodes); + nodeToReplace->add(nodes, false); nodeToReplace->replaceWith({fc}); } -- GitLab