From 894e5a768881d8e0b79b4ae60dbd240cb132f867 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 19 May 2024 10:58:35 +0200
Subject: [PATCH] Added disjoint filter

---
 include/aidge/graph/Matching.hpp   | 22 ++++++++++++++++++++--
 src/graph/Matching.cpp             | 30 +++++++++++++++++++++++++++++-
 unit_tests/graph/Test_Matching.cpp |  9 +++++++++
 3 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp
index bc36bfb28..b6aeaf2dc 100644
--- a/include/aidge/graph/Matching.hpp
+++ b/include/aidge/graph/Matching.hpp
@@ -71,6 +71,7 @@ public:
 
     /**
      * Matches a query by direct, single pass parse and match.
+     * The returned matches are non-ordered and therefore stored in a std::set.
      * 
      * Some rules:
      * - The first node of the first sequence is the root node and cannot be optional
@@ -102,7 +103,7 @@ public:
      *            inputs of the same node!
      * 
      * - In Aidge, a node output can be connected to multiple other nodes. In
-     *   your query, you can allow it or not, with the "-" or "~" modifier.
+     *   your query, you can allow it or not, with the "~" or "-" modifier.
      *   EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected
      *            to a ReLU node at their output #0.
      *            "Conv~>ReLU" will match all the Conv connected to a ReLU even
@@ -115,12 +116,23 @@ public:
      *   EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
      *            "Conv->ReLU?->Conv?->ReLU?" will return both
      *            Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2
+     *   To avoid this behavior, set the disjoint argument to true. In this case,
+     *   only Conv#1->ReLU#1->Conv#2->ReLU#2 will be kept in the example above.
      * 
      * - Whitespaces are allowed anywhere in the query
      * 
      * QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))*
+     * 
+     * @param query The query to search.
+     * @param disjoint If true, only keep the longuest disjoint (non-overlapping) matches.
+     * @return Set of matches, each stored in a MatchingResult struct.
+    */
+    std::set<MatchingResult> match(const std::string& query, bool disjoint = false);
+
+    /**
+     * Filter to keep only the longuest disjoint (non-overlapping) matches.
     */
-    std::set<MatchingResult> match(const std::string& query);
+    std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);
 
     inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) {
         mLambda[name] = func;
@@ -179,6 +191,12 @@ private:
                         str.end(),
                         [](char c) { return !std::isspace(c); }));
     }
+
+    struct CompareMatchingResultSize {
+        bool operator()(const MatchingResult& lhs, const MatchingResult& rhs) const {
+            return lhs.graph->getNodes().size() > rhs.graph->getNodes().size();
+        }
+    };
 };
 
 inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) {
diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp
index c855119a7..3ba80fd01 100644
--- a/src/graph/Matching.cpp
+++ b/src/graph/Matching.cpp
@@ -2,7 +2,7 @@
 
 #include <fmt/color.h>
 
-std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query) {
+std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::match(const std::string& query, bool disjoint) {
     Context ctx;
     ctx.query = query;
     std::set<MatchingResult> matches;
@@ -22,9 +22,37 @@ std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphM
         Log::warn("Syntax error, unable to parse remaining query: {}", ctx.query);
     }
 
+    if (disjoint) {
+        matches = filterLonguestDisjoint(matches);
+    }
+
     return matches;
 }
 
+std::set<Aidge::SinglePassGraphMatching::MatchingResult> Aidge::SinglePassGraphMatching::filterLonguestDisjoint(const std::set<MatchingResult>& matches) {
+    std::set<MatchingResult, CompareMatchingResultSize> sortedMatches(matches.begin(), matches.end());
+    std::set<NodePtr> selectedNodes;
+    std::set<MatchingResult> filteredMatches;
+
+    for (const auto& match : sortedMatches) {
+        const auto& nodes = match.graph->getNodes();
+        bool isNonOverlapping = true;
+        for (const auto& node : nodes) {
+            if (selectedNodes.find(node) != selectedNodes.end()) {
+                isNonOverlapping = false;
+                break;
+            }
+        }
+
+        if (isNonOverlapping) {
+            filteredMatches.insert(match);
+            selectedNodes.insert(nodes.begin(), nodes.end());
+        }
+    }
+
+    return filteredMatches;
+}
+
 bool Aidge::SinglePassGraphMatching::matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches) {
     auto newCtx = ctx;
     Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' '));
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index 8b665ab62..93a110ebd 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -86,6 +86,15 @@ TEST_CASE("[core/graph] Matching") {
         });
     }
 
+    SECTION("Conv->(ReLU~>Pad->Conv)* [disjoint]") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*", true);
+
+        checkMatches(results, {
+            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
+            {"conv5_conv", {"conv5_conv"}}
+        });
+    }
+
     SECTION("Conv~>(ReLU~>Pad->Conv)*") {
         const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU~>Pad->Conv)*");
 
-- 
GitLab