diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py new file mode 100644 index 0000000000000000000000000000000000000000..96ed5c42ce1d0dc557f8b9c0f12178e4b8a874dd --- /dev/null +++ b/aidge_core/unit_tests/test_recipies.py @@ -0,0 +1,41 @@ +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" + +import unittest +import aidge_core + +class test_parameters(unittest.TestCase): + """Very basic test to make sure the python APi is not broken. + Can be remove in later stage of the developpement. + """ + def setUp(self): + pass + + def tearDown(self): + pass + + def test_conv(self): + graph_view = aidge_core.sequential([ + aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), + aidge_core.FC(50, name='0') + ]) + old_nodes = graph_view.get_nodes() + aidge_core.remove_flatten(graph_view) + self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1) + self.assertTrue("Flatten0" not in [i.name for i in graph_view.get_nodes()]) + + self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) + + +if __name__ == '__main__': + unittest.main() + + + diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index bfb4c09fd0202e4aff020764722bba7afe32cb5d..9096c107ba505f5f18993a761273552408db721b 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -23,16 +23,17 @@ namespace Aidge { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - std::shared_ptr<Node> fc; + std::shared_ptr<Node> flatten; for (const auto& element : nodes) { assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); - if (element->type() == "FC"){ - fc = element; + if (element->type() == "Flatten"){ + flatten = element; } } auto g = std::make_shared<GraphView>(); - g->add(std::set<std::shared_ptr<Node>>({nodes})); - g->replaceWith({fc}); + // TODO : avoid using replace_with and use a remove method instead + g->add(std::set<std::shared_ptr<Node>>({flatten})); + g->replaceWith({}); } void removeFlatten(std::shared_ptr<GraphView> graphView){