diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index 5bba4a71622a4d5df33970ee57b0459ed06f93d3..bfb4c09fd0202e4aff020764722bba7afe32cb5d 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -22,16 +22,25 @@ namespace Aidge { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { + assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); + std::shared_ptr<Node> fc; + for (const auto& element : nodes) { + assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); + if (element->type() == "FC"){ + fc = element; + } + } auto g = std::make_shared<GraphView>(); g->add(std::set<std::shared_ptr<Node>>({nodes})); - g->replaceWith({}); + g->replaceWith({fc}); } void removeFlatten(std::shared_ptr<GraphView> graphView){ std::map<std::string,NodeRegex*> nodesRegex ; nodesRegex["Flatten"] = new NodeRegex("Flatten"); + nodesRegex["FC"] = new NodeRegex("FC"); std::vector<std::string> seqRegex; - seqRegex.push_back("Flatten;"); + seqRegex.push_back("Flatten->FC;"); GRegex GReg(nodesRegex, seqRegex); Match matches = GReg.match(graphView); std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();