diff --git a/src/recipes/RemoveFlatten.cpp b/src/recipes/RemoveFlatten.cpp index 8e59ea090c55ddb422a30acd1c7b9348ee1025c8..bf80ab51749953a5b72d0e01f186265fdbb72e81 100644 --- a/src/recipes/RemoveFlatten.cpp +++ b/src/recipes/RemoveFlatten.cpp @@ -22,28 +22,15 @@ namespace Aidge { - - void removeFlatten(const std::set<NodePtr>& solution){ - std::set<NodePtr> flattenNodes {}; - for (const auto& node : solution) { - if (node->type() == "Flatten"){ - printf("Flatten found.\n"); - flattenNodes.insert(node); - } - else if (! (node->type() == "MatMul" || node->type() == "FC")){ - AIDGE_THROW_OR_ABORT(std::runtime_error, "Node of type {} is not MatMul nor FC, an error during GraphMatching occured !", node->type()); - } - } - GraphView::replace(flattenNodes, {}); - } - void removeFlatten(std::shared_ptr<GraphView> graphView){ const auto matches = SinglePassGraphMatching(graphView).match( "(FC|MatMul)<-(Flatten)+" ); for (const auto& solution : matches) { - removeFlatten(solution.graph->getNodes()); + auto flattenNodes(solution.graph->getNodes()); + flattenNodes.erase(solution.graph->rootNode()); + GraphView::replace(flattenNodes, {}); } } }