From 515ef7d501b7a612b6caedcff5243c129c82a1bf Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 5 Sep 2024 08:09:33 +0000
Subject: [PATCH] Applied proposed change in
 https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/merge_requests/189#note_2814405.

---
 src/recipes/RemoveFlatten.cpp | 19 +++----------------
 1 file changed, 3 insertions(+), 16 deletions(-)

diff --git a/src/recipes/RemoveFlatten.cpp b/src/recipes/RemoveFlatten.cpp
index 8e59ea090..bf80ab517 100644
--- a/src/recipes/RemoveFlatten.cpp
+++ b/src/recipes/RemoveFlatten.cpp
@@ -22,28 +22,15 @@
 
 
 namespace Aidge {
-
-    void removeFlatten(const std::set<NodePtr>& solution){
-        std::set<NodePtr> flattenNodes {};
-        for (const auto& node : solution) {
-            if (node->type() == "Flatten"){
-                printf("Flatten found.\n");
-                flattenNodes.insert(node);
-            }
-            else if (! (node->type() == "MatMul" || node->type() == "FC")){
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "Node of type {} is not MatMul nor FC, an error during GraphMatching occured !", node->type());
-            }
-        }
-        GraphView::replace(flattenNodes, {});
-    }
-
     void removeFlatten(std::shared_ptr<GraphView> graphView){
         const auto matches = SinglePassGraphMatching(graphView).match(
             "(FC|MatMul)<-(Flatten)+"
         );
 
         for (const auto& solution : matches) {
-            removeFlatten(solution.graph->getNodes());
+            auto flattenNodes(solution.graph->getNodes());
+            flattenNodes.erase(solution.graph->rootNode());
+            GraphView::replace(flattenNodes, {});
         }
     }
 }
-- 
GitLab