Skip to content
Snippets Groups Projects
Commit 3ae83af7 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Use std::set instead of std::vector

parent e97fb46d
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!138Alternative graph matching
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <set>
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
...@@ -40,9 +41,13 @@ public: ...@@ -40,9 +41,13 @@ public:
}; };
struct MatchingResult { struct MatchingResult {
std::shared_ptr<GraphView> graph; // Mutable is required to allow modifying MatchingResult members with a std::set
std::map<std::string, std::map<std::string, NodePtr>> anchors; // iterator. Any change should not modify the set ordering.
NodePtr startNode; // We use graph->rootNode() as the std::set key, which is garanteed
// to never change after insertion!
mutable std::shared_ptr<GraphView> graph;
mutable std::map<std::string, std::map<std::string, NodePtr>> anchors;
mutable NodePtr startNode;
MatchingResult() { MatchingResult() {
graph = std::make_shared<GraphView>(); graph = std::make_shared<GraphView>();
...@@ -115,7 +120,7 @@ public: ...@@ -115,7 +120,7 @@ public:
* *
* QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))* * QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))*
*/ */
std::vector<MatchingResult> match(const std::string& query); std::set<MatchingResult> match(const std::string& query);
inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) {
mLambda[name] = func; mLambda[name] = func;
...@@ -128,28 +133,28 @@ private: ...@@ -128,28 +133,28 @@ private:
/** /**
* NODE_OR_BLOCK = BLOCK | NODE * NODE_OR_BLOCK = BLOCK | NODE
*/ */
bool matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches); bool matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches);
/** /**
* QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}') * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
* BLOCK = '(' SEQ | PAR | BLOCK | ALT | NODE ')' QUANTIFIER? * BLOCK = '(' SEQ | PAR | BLOCK | ALT | NODE ')' QUANTIFIER?
*/ */
bool matchBlock(Context& ctx, std::vector<MatchingResult>& matches); bool matchBlock(Context& ctx, std::set<MatchingResult>& matches);
/** /**
* SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+ * SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+
*/ */
bool matchSequence(Context& ctx, std::vector<MatchingResult>& matches); bool matchSequence(Context& ctx, std::set<MatchingResult>& matches);
/** /**
* PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+ * PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+
*/ */
bool matchParallel(Context& ctx, std::vector<MatchingResult>& matches); bool matchParallel(Context& ctx, std::set<MatchingResult>& matches);
/** /**
* ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+ * ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+
*/ */
bool matchAlternative(Context& ctx, std::vector<MatchingResult>& matches); bool matchAlternative(Context& ctx, std::set<MatchingResult>& matches);
/** /**
* IO_INDEX_ANY = '*' * IO_INDEX_ANY = '*'
...@@ -158,7 +163,7 @@ private: ...@@ -158,7 +163,7 @@ private:
* PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~') * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~')
* EDGE = CHILD_EDGE | PARENT_EDGE * EDGE = CHILD_EDGE | PARENT_EDGE
*/ */
bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches); bool matchEdge(Context& ctx, std::set<MatchingResult>& matches);
/** /**
* TYPE = [A-Za-z0-9_]+ * TYPE = [A-Za-z0-9_]+
...@@ -166,7 +171,7 @@ private: ...@@ -166,7 +171,7 @@ private:
* LAMBDA = [A-Za-z0-9_]+ * LAMBDA = [A-Za-z0-9_]+
* NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')? * NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?
*/ */
bool matchNode(Context& ctx, std::vector<MatchingResult>& matches); bool matchNode(Context& ctx, std::set<MatchingResult>& matches);
inline void removeWhiteSpace(std::string& str) { inline void removeWhiteSpace(std::string& str) {
str.erase(str.begin(), str.erase(str.begin(),
...@@ -175,6 +180,10 @@ private: ...@@ -175,6 +180,10 @@ private:
[](char c) { return !std::isspace(c); })); [](char c) { return !std::isspace(c); }));
} }
}; };
inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) {
return lhs.graph->rootNode() < rhs.graph->rootNode();
}
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */ #endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
#include <fmt/color.h> #include <fmt/color.h>
std::vector<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query) { std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query) {
Context ctx; Context ctx;
ctx.query = query; ctx.query = query;
std::vector<MatchingResult> matches; std::set<MatchingResult> matches;
while (matchSequence(ctx, matches) || matchNodeOrBlock(ctx, matches)) { while (matchSequence(ctx, matches) || matchNodeOrBlock(ctx, matches)) {
removeWhiteSpace(ctx.query); removeWhiteSpace(ctx.query);
...@@ -25,7 +25,7 @@ std::vector<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGra ...@@ -25,7 +25,7 @@ std::vector<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGra
return matches; return matches;
} }
bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches) { bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches) {
auto newCtx = ctx; auto newCtx = ctx;
Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' ')); Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' '));
auto newMatches = matches; auto newMatches = matches;
...@@ -49,7 +49,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< ...@@ -49,7 +49,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector<
} }
if (!found) { if (!found) {
newMatches.push_back(match); newMatches.insert(match);
} }
} }
...@@ -98,7 +98,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< ...@@ -98,7 +98,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector<
} }
if (matchMore) { if (matchMore) {
std::vector<MatchingResult> additionalMatches; std::set<MatchingResult> additionalMatches;
do { do {
auto additionalCtx = ctx; auto additionalCtx = ctx;
...@@ -115,7 +115,9 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< ...@@ -115,7 +115,9 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector<
for (const auto& additionalMatch : additionalMatches) { for (const auto& additionalMatch : additionalMatches) {
for (auto& match : newMatches) { for (auto& match : newMatches) {
if (match.graph->rootNode() == additionalMatch.graph->rootNode()) { if (match.graph->rootNode() == additionalMatch.graph->rootNode()) {
match = additionalMatch; match.graph = std::make_shared<GraphView>(*(additionalMatch.graph.get()));
match.anchors = additionalMatch.anchors;
match.startNode = additionalMatch.startNode;
break; break;
} }
} }
...@@ -132,7 +134,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector< ...@@ -132,7 +134,7 @@ bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::vector<
return true; return true;
} }
bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::vector<MatchingResult>& matches) { bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::set<MatchingResult>& matches) {
auto newCtx = ctx; auto newCtx = ctx;
Log::debug("{}block", std::string(2*newCtx.depth, ' ')); Log::debug("{}block", std::string(2*newCtx.depth, ' '));
auto newMatches = matches; auto newMatches = matches;
...@@ -172,7 +174,7 @@ bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::vector<Matchi ...@@ -172,7 +174,7 @@ bool Aidge::SinglePassGraphMatching::matchBlock(Context& ctx, std::vector<Matchi
return true; return true;
} }
bool Aidge::SinglePassGraphMatching::matchSequence(Context& ctx, std::vector<MatchingResult>& matches) { bool Aidge::SinglePassGraphMatching::matchSequence(Context& ctx, std::set<MatchingResult>& matches) {
auto newCtx = ctx; auto newCtx = ctx;
Log::debug("{}sequence", std::string(2*newCtx.depth, ' ')); Log::debug("{}sequence", std::string(2*newCtx.depth, ' '));
auto newMatches = matches; auto newMatches = matches;
...@@ -220,7 +222,7 @@ bool Aidge::SinglePassGraphMatching::matchSequence(Context& ctx, std::vector<Mat ...@@ -220,7 +222,7 @@ bool Aidge::SinglePassGraphMatching::matchSequence(Context& ctx, std::vector<Mat
return true; return true;
} }
bool Aidge::SinglePassGraphMatching::matchParallel(Context& ctx, std::vector<MatchingResult>& matches) { bool Aidge::SinglePassGraphMatching::matchParallel(Context& ctx, std::set<MatchingResult>& matches) {
auto newCtx = ctx; auto newCtx = ctx;
Log::debug("{}parallel", std::string(2*newCtx.depth, ' ')); Log::debug("{}parallel", std::string(2*newCtx.depth, ' '));
++newCtx.depth; ++newCtx.depth;
...@@ -274,11 +276,11 @@ bool Aidge::SinglePassGraphMatching::matchParallel(Context& ctx, std::vector<Mat ...@@ -274,11 +276,11 @@ bool Aidge::SinglePassGraphMatching::matchParallel(Context& ctx, std::vector<Mat
return true; return true;
} }
bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector<MatchingResult>& matches) { bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::set<MatchingResult>& matches) {
auto newCtx = ctx; auto newCtx = ctx;
Log::debug("{}alternative", std::string(2*newCtx.depth, ' ')); Log::debug("{}alternative", std::string(2*newCtx.depth, ' '));
++newCtx.depth; ++newCtx.depth;
std::vector<MatchingResult> newMatches; std::set<MatchingResult> newMatches;
auto altCtx = newCtx; auto altCtx = newCtx;
auto altMatches = matches; auto altMatches = matches;
...@@ -287,7 +289,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector< ...@@ -287,7 +289,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector<
return false; return false;
} }
newCtx.query = altCtx.query; newCtx.query = altCtx.query;
newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end()); newMatches.insert(altMatches.begin(), altMatches.end());
bool found = false; bool found = false;
while (true) { while (true) {
...@@ -307,7 +309,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector< ...@@ -307,7 +309,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector<
return false; return false;
} }
newCtx.query = altCtx.query; newCtx.query = altCtx.query;
newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end()); newMatches.insert(altMatches.begin(), altMatches.end());
} }
if (!found) { if (!found) {
...@@ -321,7 +323,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector< ...@@ -321,7 +323,7 @@ bool Aidge::SinglePassGraphMatching::matchAlternative(Context& ctx, std::vector<
return true; return true;
} }
bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<MatchingResult>& /*matches*/) { bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::set<MatchingResult>& /*matches*/) {
auto newCtx = ctx; auto newCtx = ctx;
Log::debug("{}edge", std::string(2*newCtx.depth, ' ')); Log::debug("{}edge", std::string(2*newCtx.depth, ' '));
...@@ -419,7 +421,7 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin ...@@ -419,7 +421,7 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin
return true; return true;
} }
bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& matches) { bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::set<MatchingResult>& matches) {
auto newCtx = ctx; auto newCtx = ctx;
Log::debug("{}node", std::string(2*newCtx.depth, ' ')); Log::debug("{}node", std::string(2*newCtx.depth, ' '));
auto newMatches = matches; auto newMatches = matches;
...@@ -492,7 +494,7 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -492,7 +494,7 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
result.anchors[type][anchor] = node; result.anchors[type][anchor] = node;
} }
result.startNode = node; result.startNode = node;
newMatches.push_back(result); newMatches.insert(result);
} }
} }
newCtx.firstSequence = false; newCtx.firstSequence = false;
...@@ -501,28 +503,28 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -501,28 +503,28 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
} }
else if (newCtx.firstNode) { else if (newCtx.firstNode) {
// First node of a (new) sequence: it has to be an existing anchor // First node of a (new) sequence: it has to be an existing anchor
for (size_t i = 0; i < newMatches.size(); ) { for (auto it = newMatches.begin(); it != newMatches.end(); ) {
const auto anchors = newMatches[i].anchors[type]; const auto anchors = it->anchors[type];
const auto anchorNode = anchors.find(anchor); const auto anchorNode = anchors.find(anchor);
if (anchorNode != anchors.end()) { if (anchorNode != anchors.end()) {
newMatches[i].startNode = anchorNode->second; it->startNode = anchorNode->second;
++i; ++it;
} }
else { else {
newMatches.erase(newMatches.begin() + i); it = newMatches.erase(it);
} }
} }
Log::debug("{}anchor node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size()); Log::debug("{}anchor node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size());
} }
else { else {
for (size_t i = 0; i < newMatches.size(); ) { for (auto it = newMatches.begin(); it != newMatches.end(); ) {
bool found = false; bool found = false;
if (newCtx.lookForChild) { if (newCtx.lookForChild) {
const auto outputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) const auto outputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex)
? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(newMatches[i].startNode->output(newCtx.edgeLeftIdx))) ? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(it->startNode->output(newCtx.edgeLeftIdx)))
: newMatches[i].startNode->outputs(); : it->startNode->outputs();
for (const auto& output : outputs) { for (const auto& output : outputs) {
if (newCtx.singleOutput && output.size() > 1) { if (newCtx.singleOutput && output.size() > 1) {
...@@ -534,17 +536,17 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -534,17 +536,17 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
&& (lambda.empty() || mLambda.at(lambda)(node.first)) && (lambda.empty() || mLambda.at(lambda)(node.first))
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx))
{ {
if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) { if (mGraph->inView(node.first) && !it->graph->inView(node.first)) {
newMatches[i].graph->add(node.first, false); it->graph->add(node.first, false);
if (!anchor.empty()) { if (!anchor.empty()) {
newMatches[i].anchors[type][anchor] = node.first; it->anchors[type][anchor] = node.first;
} }
newMatches[i].startNode = node.first; it->startNode = node.first;
found = true; found = true;
break; break;
} }
else if (!anchor.empty() && newMatches[i].anchors[type].find(anchor) != newMatches[i].anchors[type].end()) { else if (!anchor.empty() && it->anchors[type].find(anchor) != it->anchors[type].end()) {
newMatches[i].startNode = node.first; it->startNode = node.first;
found = true; found = true;
break; break;
} }
...@@ -558,8 +560,8 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -558,8 +560,8 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
} }
else { else {
const auto inputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) const auto inputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex)
? std::vector<std::pair<NodePtr, IOIndex_t>>(1, newMatches[i].startNode->input(newCtx.edgeLeftIdx)) ? std::vector<std::pair<NodePtr, IOIndex_t>>(1, it->startNode->input(newCtx.edgeLeftIdx))
: newMatches[i].startNode->inputs(); : it->startNode->inputs();
for (const auto& input : inputs) { for (const auto& input : inputs) {
if ((type.empty() || input.first->type() == type) if ((type.empty() || input.first->type() == type)
...@@ -570,17 +572,17 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -570,17 +572,17 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
continue; continue;
} }
if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) { if (mGraph->inView(input.first) && !it->graph->inView(input.first)) {
newMatches[i].graph->add(input.first, false); it->graph->add(input.first, false);
if (!anchor.empty()) { if (!anchor.empty()) {
newMatches[i].anchors[type][anchor] = input.first; it->anchors[type][anchor] = input.first;
} }
newMatches[i].startNode = input.first; it->startNode = input.first;
found = true; found = true;
break; break;
} }
else if (!anchor.empty() && newMatches[i].anchors[type].find(anchor) != newMatches[i].anchors[type].end()) { else if (!anchor.empty() && it->anchors[type].find(anchor) != it->anchors[type].end()) {
newMatches[i].startNode = input.first; it->startNode = input.first;
found = true; found = true;
break; break;
} }
...@@ -589,10 +591,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -589,10 +591,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
} }
if (found) { if (found) {
++i; ++it;
} }
else { else {
newMatches.erase(newMatches.begin() + i); it = newMatches.erase(it);
} }
} }
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
using namespace Aidge; using namespace Aidge;
void checkMatches(const std::vector<SinglePassGraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) { void checkMatches(const std::set<SinglePassGraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) {
REQUIRE(results.size() == expected.size()); REQUIRE(results.size() == expected.size());
for (const auto& result : results) { for (const auto& result : results) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment