From 894e5a768881d8e0b79b4ae60dbd240cb132f867 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 19 May 2024 10:58:35 +0200 Subject: [PATCH] Added disjoint filter --- include/aidge/graph/Matching.hpp | 22 ++++++++++++++++++++-- src/graph/Matching.cpp | 30 +++++++++++++++++++++++++++++- unit_tests/graph/Test_Matching.cpp | 9 +++++++++ 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index bc36bfb28..b6aeaf2dc 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -71,6 +71,7 @@ public: /** * Matches a query by direct, single pass parse and match. + * The returned matches are non-ordered and therefore stored in a std::set. * * Some rules: * - The first node of the first sequence is the root node and cannot be optional @@ -102,7 +103,7 @@ public: * inputs of the same node! * * - In Aidge, a node output can be connected to multiple other nodes. In - * your query, you can allow it or not, with the "-" or "~" modifier. + * your query, you can allow it or not, with the "~" or "-" modifier. * EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected * to a ReLU node at their output #0. * "Conv~>ReLU" will match all the Conv connected to a ReLU even @@ -115,12 +116,23 @@ public: * EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2 * "Conv->ReLU?->Conv?->ReLU?" will return both * Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2 + * To avoid this behavior, set the disjoint argument to true. In this case, + * only Conv#1->ReLU#1->Conv#2->ReLU#2 will be kept in the example above. * * - Whitespaces are allowed anywhere in the query * * QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))* + * + * @param query The query to search. + * @param disjoint If true, only keep the longuest disjoint (non-overlapping) matches. + * @return Set of matches, each stored in a MatchingResult struct. + */ + std::set<MatchingResult> match(const std::string& query, bool disjoint = false); + + /** + * Filter to keep only the longuest disjoint (non-overlapping) matches. */ - std::set<MatchingResult> match(const std::string& query); + std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches); inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { mLambda[name] = func; @@ -179,6 +191,12 @@ private: str.end(), [](char c) { return !std::isspace(c); })); } + + struct CompareMatchingResultSize { + bool operator()(const MatchingResult& lhs, const MatchingResult& rhs) const { + return lhs.graph->getNodes().size() > rhs.graph->getNodes().size(); + } + }; }; inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) { diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index c855119a7..3ba80fd01 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -2,7 +2,7 @@ #include <fmt/color.h> -std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query) { +std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query, bool disjoint) { Context ctx; ctx.query = query; std::set<MatchingResult> matches; @@ -22,9 +22,37 @@ std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphM Log::warn("Syntax error, unable to parse remaining query: {}", ctx.query); } + if (disjoint) { + matches = filterLonguestDisjoint(matches); + } + return matches; } +std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::filterLonguestDisjoint(const std::set<MatchingResult>& matches) { + std::set<MatchingResult, CompareMatchingResultSize> sortedMatches(matches.begin(), matches.end()); + std::set<NodePtr> selectedNodes; + std::set<MatchingResult> filteredMatches; + + for (const auto& match : sortedMatches) { + const auto& nodes = match.graph->getNodes(); + bool isNonOverlapping = true; + for (const auto& node : nodes) { + if (selectedNodes.find(node) != selectedNodes.end()) { + isNonOverlapping = false; + break; + } + } + + if (isNonOverlapping) { + filteredMatches.insert(match); + selectedNodes.insert(nodes.begin(), nodes.end()); + } + } + + return filteredMatches; +} + bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches) { auto newCtx = ctx; Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' ')); diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index 8b665ab62..93a110ebd 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -86,6 +86,15 @@ TEST_CASE("[core/graph] Matching") { }); } + SECTION("Conv->(ReLU~>Pad->Conv)* [disjoint]") { + const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*", true); + + checkMatches(results, { + {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}}, + {"conv5_conv", {"conv5_conv"}} + }); + } + SECTION("Conv~>(ReLU~>Pad->Conv)*") { const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU~>Pad->Conv)*"); -- GitLab