From f4eb185fb2b09c1430081c536ddc88071e0eb2f4 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Wed, 13 Sep 2023 08:36:31 +0000 Subject: [PATCH] [RemoveFlatten] Update recipies to remove flatten only if it is before an FC ! --- src/recipies/RemoveFlatten.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index 5bba4a716..bfb4c09fd 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -22,16 +22,25 @@ namespace Aidge { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { + assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); + std::shared_ptr<Node> fc; + for (const auto& element : nodes) { + assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); + if (element->type() == "FC"){ + fc = element; + } + } auto g = std::make_shared<GraphView>(); g->add(std::set<std::shared_ptr<Node>>({nodes})); - g->replaceWith({}); + g->replaceWith({fc}); } void removeFlatten(std::shared_ptr<GraphView> graphView){ std::map<std::string,NodeRegex*> nodesRegex ; nodesRegex["Flatten"] = new NodeRegex("Flatten"); + nodesRegex["FC"] = new NodeRegex("FC"); std::vector<std::string> seqRegex; - seqRegex.push_back("Flatten;"); + seqRegex.push_back("Flatten->FC;"); GRegex GReg(nodesRegex, seqRegex); Match matches = GReg.match(graphView); std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); -- GitLab