diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 52ee65cbee77f8c9b2a6542e44f1362260a065c9..dae720da15d976957179f60ea84aff8cea96d210 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -22,8 +22,8 @@ class test_recipies(unittest.TestCase): def test_remove_dropout(self): graph_view = aidge_core.sequential([ - aidge_core.GenericOperator("Conv", 1, 0, 1, "Conv0"), - aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0") + aidge_core.GenericOperator("Conv", 1, 0, 1, "Conv0"); + aidge_core.GenericOperator("Dropout", 1, 0, 1, name="Dropout0") ]) old_nodes = graph_view.get_nodes() aidge_core.remove_dropout(graph_view) diff --git a/src/recipies/RemoveDropout.cpp b/src/recipies/RemoveDropout.cpp index a159f3b85079a54dc140b1bdaf2d3d9fd21528be..1dedac8f19e6ec6b4b1f6dabb6bd3e9b8c759def 100644 --- a/src/recipies/RemoveDropout.cpp +++ b/src/recipies/RemoveDropout.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <memory> +#include <iostream> #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" @@ -47,7 +48,7 @@ namespace Aidge { void removeDropout(std::shared_ptr<GraphView> graphView){ std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); regex->setNodeKey("Dropout","getType($) =='Dropout'"); - regex->addQuery("Dropout"); + regex->addQuery("Dropout#"); for (const auto& solution : regex->match(graphView)) { removeDropout(solution);