Skip to content
Snippets Groups Projects
Commit 894e5a76 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added disjoint filter

parent 3ae83af7
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!138Alternative graph matching
Pipeline #46095 failed
......@@ -71,6 +71,7 @@ public:
/**
* Matches a query by direct, single pass parse and match.
* The returned matches are non-ordered and therefore stored in a std::set.
*
* Some rules:
* - The first node of the first sequence is the root node and cannot be optional
......@@ -102,7 +103,7 @@ public:
* inputs of the same node!
*
* - In Aidge, a node output can be connected to multiple other nodes. In
* your query, you can allow it or not, with the "-" or "~" modifier.
* your query, you can allow it or not, with the "~" or "-" modifier.
* EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected
* to a ReLU node at their output #0.
* "Conv~>ReLU" will match all the Conv connected to a ReLU even
......@@ -115,12 +116,23 @@ public:
* EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
* "Conv->ReLU?->Conv?->ReLU?" will return both
* Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2
* To avoid this behavior, set the disjoint argument to true. In this case,
* only Conv#1->ReLU#1->Conv#2->ReLU#2 will be kept in the example above.
*
* - Whitespaces are allowed anywhere in the query
*
* QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))*
*
* @param query The query to search.
* @param disjoint If true, only keep the longuest disjoint (non-overlapping) matches.
* @return Set of matches, each stored in a MatchingResult struct.
*/
std::set<MatchingResult> match(const std::string& query, bool disjoint = false);
/**
* Filter to keep only the longuest disjoint (non-overlapping) matches.
*/
std::set<MatchingResult> match(const std::string& query);
std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);
inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) {
mLambda[name] = func;
......@@ -179,6 +191,12 @@ private:
str.end(),
[](char c) { return !std::isspace(c); }));
}
struct CompareMatchingResultSize {
bool operator()(const MatchingResult& lhs, const MatchingResult& rhs) const {
return lhs.graph->getNodes().size() > rhs.graph->getNodes().size();
}
};
};
inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) {
......
......@@ -2,7 +2,7 @@
#include <fmt/color.h>
std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query) {
std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query, bool disjoint) {
Context ctx;
ctx.query = query;
std::set<MatchingResult> matches;
......@@ -22,9 +22,37 @@ std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphM
Log::warn("Syntax error, unable to parse remaining query: {}", ctx.query);
}
if (disjoint) {
matches = filterLonguestDisjoint(matches);
}
return matches;
}
std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::filterLonguestDisjoint(const std::set<MatchingResult>& matches) {
std::set<MatchingResult, CompareMatchingResultSize> sortedMatches(matches.begin(), matches.end());
std::set<NodePtr> selectedNodes;
std::set<MatchingResult> filteredMatches;
for (const auto& match : sortedMatches) {
const auto& nodes = match.graph->getNodes();
bool isNonOverlapping = true;
for (const auto& node : nodes) {
if (selectedNodes.find(node) != selectedNodes.end()) {
isNonOverlapping = false;
break;
}
}
if (isNonOverlapping) {
filteredMatches.insert(match);
selectedNodes.insert(nodes.begin(), nodes.end());
}
}
return filteredMatches;
}
bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches) {
auto newCtx = ctx;
Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' '));
......
......@@ -86,6 +86,15 @@ TEST_CASE("[core/graph] Matching") {
});
}
SECTION("Conv->(ReLU~>Pad->Conv)* [disjoint]") {
const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*", true);
checkMatches(results, {
{"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
{"conv5_conv", {"conv5_conv"}}
});
}
SECTION("Conv~>(ReLU~>Pad->Conv)*") {
const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU~>Pad->Conv)*");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment