From 0fbf84e71edb37e49ae68afc26fd2d66ac2e4847 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Wed, 13 Sep 2023 09:21:38 +0000
Subject: [PATCH] [RemoveFlatten] Add python unittest.

---
 aidge_core/unit_tests/test_recipies.py | 41 ++++++++++++++++++++++++++
 src/recipies/RemoveFlatten.cpp         | 11 +++----
 2 files changed, 47 insertions(+), 5 deletions(-)
 create mode 100644 aidge_core/unit_tests/test_recipies.py

diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py
new file mode 100644
index 000000000..96ed5c42c
--- /dev/null
+++ b/aidge_core/unit_tests/test_recipies.py
@@ -0,0 +1,41 @@
+"""
+Copyright (c) 2023 CEA-List
+
+This program and the accompanying materials are made available under the
+terms of the Eclipse Public License 2.0 which is available at
+http://www.eclipse.org/legal/epl-2.0.
+
+SPDX-License-Identifier: EPL-2.0
+"""
+
+import unittest
+import aidge_core
+
+class test_parameters(unittest.TestCase):
+    """Very basic test to make sure the python APi is not broken.
+    Can be remove in later stage of the developpement.
+    """
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        pass
+
+    def test_conv(self):
+        graph_view = aidge_core.sequential([
+            aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
+            aidge_core.FC(50, name='0')
+        ])
+        old_nodes = graph_view.get_nodes()
+        aidge_core.remove_flatten(graph_view)
+        self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1)
+        self.assertTrue("Flatten0" not in [i.name for i in graph_view.get_nodes()])
+
+        self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
+
+
+if __name__ == '__main__':
+    unittest.main()
+
+
+
diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp
index bfb4c09fd..9096c107b 100644
--- a/src/recipies/RemoveFlatten.cpp
+++ b/src/recipies/RemoveFlatten.cpp
@@ -23,16 +23,17 @@
 namespace Aidge {
     void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
         assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
-        std::shared_ptr<Node> fc;
+        std::shared_ptr<Node> flatten;
         for (const auto& element : nodes) {
             assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
-            if (element->type() == "FC"){
-                fc = element;
+            if (element->type() == "Flatten"){
+                flatten = element;
             }
         }
         auto g = std::make_shared<GraphView>();
-        g->add(std::set<std::shared_ptr<Node>>({nodes}));
-        g->replaceWith({fc});
+        // TODO : avoid using replace_with and use a remove method instead
+        g->add(std::set<std::shared_ptr<Node>>({flatten}));
+        g->replaceWith({});
     }
 
     void removeFlatten(std::shared_ptr<GraphView> graphView){
-- 
GitLab