From 515ef7d501b7a612b6caedcff5243c129c82a1bf Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 5 Sep 2024 08:09:33 +0000 Subject: [PATCH] Applied proposed change in https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/merge_requests/189#note_2814405. --- src/recipes/RemoveFlatten.cpp | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/src/recipes/RemoveFlatten.cpp b/src/recipes/RemoveFlatten.cpp index 8e59ea090..bf80ab517 100644 --- a/src/recipes/RemoveFlatten.cpp +++ b/src/recipes/RemoveFlatten.cpp @@ -22,28 +22,15 @@ namespace Aidge { - - void removeFlatten(const std::set<NodePtr>& solution){ - std::set<NodePtr> flattenNodes {}; - for (const auto& node : solution) { - if (node->type() == "Flatten"){ - printf("Flatten found.\n"); - flattenNodes.insert(node); - } - else if (! (node->type() == "MatMul" || node->type() == "FC")){ - AIDGE_THROW_OR_ABORT(std::runtime_error, "Node of type {} is not MatMul nor FC, an error during GraphMatching occured !", node->type()); - } - } - GraphView::replace(flattenNodes, {}); - } - void removeFlatten(std::shared_ptr<GraphView> graphView){ const auto matches = SinglePassGraphMatching(graphView).match( "(FC|MatMul)<-(Flatten)+" ); for (const auto& solution : matches) { - removeFlatten(solution.graph->getNodes()); + auto flattenNodes(solution.graph->getNodes()); + flattenNodes.erase(solution.graph->rootNode()); + GraphView::replace(flattenNodes, {}); } } } -- GitLab