diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 15ec22d3eddebb3d5475b82d861eb96b62f5a0b7..754907443530f7e73d1e10ed9549d0c8eb78a011 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -60,15 +60,15 @@ class test_recipies(unittest.TestCase): aidge_core.fuse_mul_add(graph_view) self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2) - self.assertTrue("MatMul0" not in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("Add0" not in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("MatMul1" not in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("Add1" not in [i.name for i in graph_view.get_nodes()]) - - self.assertTrue("W0" in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("B0" in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("W1" in [i.name for i in graph_view.get_nodes()]) - self.assertTrue("B1" in [i.name for i in graph_view.get_nodes()]) + self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("Add0" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("MatMul1" not in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("Add1" not in [i.name() for i in graph_view.get_nodes()]) + + self.assertTrue("W0" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("B0" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("W1" in [i.name() for i in graph_view.get_nodes()]) + self.assertTrue("B1" in [i.name() for i in graph_view.get_nodes()]) # TODO : Vérifier que FC bien crée if __name__ == '__main__': diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index c31578d91a5e08d713d47ba698eddae6f4c2fc68..1de79890f9b597c4baff7427e01d7217f9695a44 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -58,9 +58,9 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // link weights & bias if (matmul->getParent(1)==nullptr) { matmul->getParent(0)->addChild(fc, 0, 1); - printf("Matmul out[1] == nullptr !\n"); + printf("MatMul out[1] == nullptr !\n"); } else { - printf("Matmul out[1] != nullptr !\n"); + printf("MatMul out[1] != nullptr !\n"); if (matmul->getParent(0)!=nullptr) matmul->getParent(0)->addChild(fc, 0, 0); matmul->input(1).first->addChild(fc, 0, 1); @@ -73,7 +73,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ? auto nodeToReplace = std::make_shared<GraphView>(); - nodeToReplace->add(nodes); + nodeToReplace->add(nodes, false); nodeToReplace->replaceWith({fc}); }