From 878de3161e0f5c910821e1539b1f754dc709e8bc Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Fri, 15 Dec 2023 14:48:09 +0100 Subject: [PATCH] fix Dropout Regex Request and test --- aidge_core/unit_tests/test_recipies.py | 4 ++-- src/recipies/RemoveDropout.cpp | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 52ee65cbe..dae720da1 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -22,8 +22,8 @@ class test_recipies(unittest.TestCase): def test_remove_dropout(self): graph_view = aidge_core.sequential([ - aidge_core.GenericOperator("Conv", 1, 0, 1, "Conv0"), - aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0") + aidge_core.GenericOperator("Conv", 1, 0, 1, "Conv0"); + aidge_core.GenericOperator("Dropout", 1, 0, 1, name="Dropout0") ]) old_nodes = graph_view.get_nodes() aidge_core.remove_dropout(graph_view) diff --git a/src/recipies/RemoveDropout.cpp b/src/recipies/RemoveDropout.cpp index a159f3b85..1dedac8f1 100644 --- a/src/recipies/RemoveDropout.cpp +++ b/src/recipies/RemoveDropout.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <memory> +#include <iostream> #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" @@ -47,7 +48,7 @@ namespace Aidge { void removeDropout(std::shared_ptr<GraphView> graphView){ std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); regex->setNodeKey("Dropout","getType($) =='Dropout'"); - regex->addQuery("Dropout"); + regex->addQuery("Dropout#"); for (const auto& solution : regex->match(graphView)) { removeDropout(solution); -- GitLab