diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index 354eaa5752a933bbd755531edf5c93f15f995774..bc36bfb282356cf23df8875031f300a1f12e3153 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -14,6 +14,7 @@ #include <map> #include <memory> +#include <set> #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" @@ -40,9 +41,13 @@ public: }; struct MatchingResult { - std::shared_ptr<GraphView> graph; - std::map<std::string, std::map<std::string, NodePtr>> anchors; - NodePtr startNode; + // Mutable is required to allow modifying MatchingResult members with a std::set + // iterator. Any change should not modify the set ordering. + // We use graph->rootNode() as the std::set key, which is garanteed + // to never change after insertion! + mutable std::shared_ptr<GraphView> graph; + mutable std::map<std::string, std::map<std::string, NodePtr>> anchors; + mutable NodePtr startNode; MatchingResult() { graph = std::make_shared<GraphView>(); @@ -115,7 +120,7 @@ public: * * QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))* */ - std::vector<MatchingResult> match(const std::string& query); + std::set<MatchingResult> match(const std::string& query); inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { mLambda[name] = func; @@ -128,28 +133,28 @@ private: /** * NODE_OR_BLOCK = BLOCK | NODE */ - bool matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches); + bool matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches); /** * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}') * BLOCK = '(' SEQ | PAR | BLOCK | ALT | NODE ')' QUANTIFIER? */ - bool matchBlock(Context& ctx, std::vector<MatchingResult>& matches); + bool matchBlock(Context& ctx, std::set<MatchingResult>& matches); /** * SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+ */ - bool matchSequence(Context& ctx, std::vector<MatchingResult>& matches); + bool matchSequence(Context& ctx, std::set<MatchingResult>& matches); /** * PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+ */ - bool matchParallel(Context& ctx, std::vector<MatchingResult>& matches); + bool matchParallel(Context& ctx, std::set<MatchingResult>& matches); /** * ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+ */ - bool matchAlternative(Context& ctx, std::vector<MatchingResult>& matches); + bool matchAlternative(Context& ctx, std::set<MatchingResult>& matches); /** * IO_INDEX_ANY = '*' @@ -158,7 +163,7 @@ private: * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~') * EDGE = CHILD_EDGE | PARENT_EDGE */ - bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches); + bool matchEdge(Context& ctx, std::set<MatchingResult>& matches); /** * TYPE = [A-Za-z0-9_]+ @@ -166,7 +171,7 @@ private: * LAMBDA = [A-Za-z0-9_]+ * NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')? */ - bool matchNode(Context& ctx, std::vector<MatchingResult>& matches); + bool matchNode(Context& ctx, std::set<MatchingResult>& matches); inline void removeWhiteSpace(std::string& str) { str.erase(str.begin(), @@ -175,6 +180,10 @@ private: [](char c) { return !std::isspace(c); })); } }; + +inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) { + return lhs.graph->rootNode() < rhs.graph->rootNode(); +} } // namespace Aidge #endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */ diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index aae3175258760375782363be620add84e2abcee3..c855119a7736e043b6e10b1030a08b750985a002 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -2,10 +2,10 @@ #include <fmt/color.h> -std::vector<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query) { +std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query) { Context ctx; ctx.query = query; - std::vector<MatchingResult> matches; + std::set<MatchingResult> matches; while (matchSequence(ctx, matches) || matchNodeOrBlock(ctx, matches)) { removeWhiteSpace(ctx.query); @@ -25,7 +25,7 @@ std::vector<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGra return matches; } -bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches) { +bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches) { auto newCtx = ctx; Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' ')); auto newMatches = matches; @@ -49,7 +49,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< } if (!found) { - newMatches.push_back(match); + newMatches.insert(match); } } @@ -98,7 +98,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< } if (matchMore) { - std::vector<MatchingResult> additionalMatches; + std::set<MatchingResult> additionalMatches; do { auto additionalCtx = ctx; @@ -115,7 +115,9 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< for (const auto& additionalMatch : additionalMatches) { for (auto& match : newMatches) { if (match.graph->rootNode() == additionalMatch.graph->rootNode()) { - match = additionalMatch; + match.graph = std::make_shared<GraphView>(*(additionalMatch.graph.get())); + match.anchors = additionalMatch.anchors; + match.startNode = additionalMatch.startNode; break; } } @@ -132,7 +134,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< return true; } -bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::vector<MatchingResult>& matches) { +bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::set<MatchingResult>& matches) { auto newCtx = ctx; Log::debug("{}block", std::string(2*newCtx.depth, ' ')); auto newMatches = matches; @@ -172,7 +174,7 @@ bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::vector<Matchi return true; } -bool Aidge::SinglePassGraphMatching::matchSequence(Context& ctx, std::vector<MatchingResult>& matches) { +bool Aidge::SinglePassGraphMatching::matchSequence(Context& ctx, std::set<MatchingResult>& matches) { auto newCtx = ctx; Log::debug("{}sequence", std::string(2*newCtx.depth, ' ')); auto newMatches = matches; @@ -220,7 +222,7 @@ bool Aidge::SinglePassGraphMatching::matchSequence(Context& ctx, std::vector<Mat return true; } -bool Aidge::SinglePassGraphMatching::matchParallel(Context& ctx, std::vector<MatchingResult>& matches) { +bool Aidge::SinglePassGraphMatching::matchParallel(Context& ctx, std::set<MatchingResult>& matches) { auto newCtx = ctx; Log::debug("{}parallel", std::string(2*newCtx.depth, ' ')); ++newCtx.depth; @@ -274,11 +276,11 @@ bool Aidge::SinglePassGraphMatching::matchParallel(Context& ctx, std::vector<Mat return true; } -bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector<MatchingResult>& matches) { +bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::set<MatchingResult>& matches) { auto newCtx = ctx; Log::debug("{}alternative", std::string(2*newCtx.depth, ' ')); ++newCtx.depth; - std::vector<MatchingResult> newMatches; + std::set<MatchingResult> newMatches; auto altCtx = newCtx; auto altMatches = matches; @@ -287,7 +289,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector< return false; } newCtx.query = altCtx.query; - newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end()); + newMatches.insert(altMatches.begin(), altMatches.end()); bool found = false; while (true) { @@ -307,7 +309,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector< return false; } newCtx.query = altCtx.query; - newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end()); + newMatches.insert(altMatches.begin(), altMatches.end()); } if (!found) { @@ -321,7 +323,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector< return true; } -bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<MatchingResult>& /*matches*/) { +bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::set<MatchingResult>& /*matches*/) { auto newCtx = ctx; Log::debug("{}edge", std::string(2*newCtx.depth, ' ')); @@ -419,7 +421,7 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin return true; } -bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& matches) { +bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingResult>& matches) { auto newCtx = ctx; Log::debug("{}node", std::string(2*newCtx.depth, ' ')); auto newMatches = matches; @@ -492,7 +494,7 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin result.anchors[type][anchor] = node; } result.startNode = node; - newMatches.push_back(result); + newMatches.insert(result); } } newCtx.firstSequence = false; @@ -501,28 +503,28 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin } else if (newCtx.firstNode) { // First node of a (new) sequence: it has to be an existing anchor - for (size_t i = 0; i < newMatches.size(); ) { - const auto anchors = newMatches[i].anchors[type]; + for (auto it = newMatches.begin(); it != newMatches.end(); ) { + const auto anchors = it->anchors[type]; const auto anchorNode = anchors.find(anchor); if (anchorNode != anchors.end()) { - newMatches[i].startNode = anchorNode->second; - ++i; + it->startNode = anchorNode->second; + ++it; } else { - newMatches.erase(newMatches.begin() + i); + it = newMatches.erase(it); } } Log::debug("{}anchor node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size()); } else { - for (size_t i = 0; i < newMatches.size(); ) { + for (auto it = newMatches.begin(); it != newMatches.end(); ) { bool found = false; if (newCtx.lookForChild) { const auto outputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) - ? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(newMatches[i].startNode->output(newCtx.edgeLeftIdx))) - : newMatches[i].startNode->outputs(); + ? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(it->startNode->output(newCtx.edgeLeftIdx))) + : it->startNode->outputs(); for (const auto& output : outputs) { if (newCtx.singleOutput && output.size() > 1) { @@ -534,17 +536,17 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin && (lambda.empty() || mLambda.at(lambda)(node.first)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx)) { - if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) { - newMatches[i].graph->add(node.first, false); + if (mGraph->inView(node.first) && !it->graph->inView(node.first)) { + it->graph->add(node.first, false); if (!anchor.empty()) { - newMatches[i].anchors[type][anchor] = node.first; + it->anchors[type][anchor] = node.first; } - newMatches[i].startNode = node.first; + it->startNode = node.first; found = true; break; } - else if (!anchor.empty() && newMatches[i].anchors[type].find(anchor) != newMatches[i].anchors[type].end()) { - newMatches[i].startNode = node.first; + else if (!anchor.empty() && it->anchors[type].find(anchor) != it->anchors[type].end()) { + it->startNode = node.first; found = true; break; } @@ -558,8 +560,8 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin } else { const auto inputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) - ? std::vector<std::pair<NodePtr, IOIndex_t>>(1, newMatches[i].startNode->input(newCtx.edgeLeftIdx)) - : newMatches[i].startNode->inputs(); + ? std::vector<std::pair<NodePtr, IOIndex_t>>(1, it->startNode->input(newCtx.edgeLeftIdx)) + : it->startNode->inputs(); for (const auto& input : inputs) { if ((type.empty() || input.first->type() == type) @@ -570,17 +572,17 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin continue; } - if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) { - newMatches[i].graph->add(input.first, false); + if (mGraph->inView(input.first) && !it->graph->inView(input.first)) { + it->graph->add(input.first, false); if (!anchor.empty()) { - newMatches[i].anchors[type][anchor] = input.first; + it->anchors[type][anchor] = input.first; } - newMatches[i].startNode = input.first; + it->startNode = input.first; found = true; break; } - else if (!anchor.empty() && newMatches[i].anchors[type].find(anchor) != newMatches[i].anchors[type].end()) { - newMatches[i].startNode = input.first; + else if (!anchor.empty() && it->anchors[type].find(anchor) != it->anchors[type].end()) { + it->startNode = input.first; found = true; break; } @@ -589,10 +591,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin } if (found) { - ++i; + ++it; } else { - newMatches.erase(newMatches.begin() + i); + it = newMatches.erase(it); } } diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index c9d7faa8159949d0517331591b5f26200269dc11..8b665ab6283343ef51c34d9ce3ff07a507304150 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -25,7 +25,7 @@ using namespace Aidge; -void checkMatches(const std::vector<SinglePassGraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) { +void checkMatches(const std::set<SinglePassGraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) { REQUIRE(results.size() == expected.size()); for (const auto& result : results) {