From 46d53493bae91e137f309c7f6a005b2347120b9d Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 17 May 2024 18:37:26 +0200
Subject: [PATCH] Changed default behavior to single output matching and added
 modifier to match multiple outputs

---
 include/aidge/graph/Matching.hpp   |  14 ++-
 src/graph/Matching.cpp             |  14 ++-
 unit_tests/graph/Test_Matching.cpp | 136 +++++++++++++++++++++++++----
 3 files changed, 145 insertions(+), 19 deletions(-)

diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp
index 42bb0b6ef..004011481 100644
--- a/include/aidge/graph/Matching.hpp
+++ b/include/aidge/graph/Matching.hpp
@@ -33,6 +33,7 @@ public:
         bool firstNode = true;
         bool inSequence = false;
         bool lookForChild = true;
+        bool singleOutput = true;
         IOIndex_t edgeLeftIdx = 0;
         IOIndex_t edgeRightIdx = 0;
         size_t depth = 0;
@@ -95,6 +96,15 @@ public:
      *            Note that the anchor is required since we intend to match several
      *            inputs of the same node!
      * 
+     * - In Aidge, a node output can be connected to multiple other nodes. In a
+     *   query, the "-" edge modifier forbids this and the "~" modifier allows it.
+     *   EXAMPLE: "Conv->ReLU" will match only the Conv nodes whose output #0 is
+     *            connected to a single ReLU node and to nothing else.
+     *            "Conv~>ReLU" will match every Conv connected to a ReLU, even
+     *            if other nodes are also connected to the same output #0.
+     *   When implementing a match & replace recipe with "~", take care not to
+     *   break the branches that leave the middle of your matching result!
+     * 
      * - The matching results can be overlapping, meaning that some nodes may be
      *   found in multiple results. Some results may be subsets of other results.
      *   EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
@@ -144,8 +154,8 @@ private:
     /**
      * IO_INDEX_ANY = '*'
      * IO_INDEX = IO_INDEX_ANY | [0-9]+
-     * CHILD_EDGE = '-' (IO_INDEX '-')? IO_INDEX? '>'
-     * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? '-'
+     * CHILD_EDGE = ('-' | '~') (IO_INDEX '-')? IO_INDEX? '>'
+     * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~')
      * EDGE = CHILD_EDGE | PARENT_EDGE
     */
     bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches);
diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp
index 3b1c4806a..aae317525 100644
--- a/src/graph/Matching.cpp
+++ b/src/graph/Matching.cpp
@@ -326,7 +326,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin
     Log::debug("{}edge", std::string(2*newCtx.depth, ' '));
 
     removeWhiteSpace(newCtx.query);
-    if (!newCtx.query.empty() && newCtx.query[0] == '-') {
+    if (!newCtx.query.empty() && (newCtx.query[0] == '-' || newCtx.query[0] == '~')) {
+        newCtx.singleOutput = (newCtx.query[0] == '-');
         newCtx.query.erase(0, 1); // drop '-'
         newCtx.lookForChild = true;
     }
@@ -381,7 +382,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin
     if (newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '>') {
         newCtx.query.erase(0, 1); // drop '>'
     }
-    else if (!newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '-') {
+    else if (!newCtx.lookForChild && !newCtx.query.empty() && (newCtx.query[0] == '-' || newCtx.query[0] == '~')) {
+        newCtx.singleOutput = (newCtx.query[0] == '-');
         newCtx.query.erase(0, 1); // drop '-'
     }
     else {
@@ -523,6 +525,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
                     : newMatches[i].startNode->outputs();
 
                 for (const auto& output : outputs) {
+                    if (newCtx.singleOutput && output.size() > 1) {
+                        continue;
+                    }
+
                     for (const auto& node : output) {
                         if ((type.empty() || node.first->type() == type)
                             && (lambda.empty() || mLambda.at(lambda)(node.first))
@@ -560,6 +566,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
                         && (lambda.empty() || mLambda.at(lambda)(input.first))
                         && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx))
                     {
+                        if (newCtx.singleOutput && input.first->getChildren(input.second).size() > 1) {
+                            continue;
+                        }
+
                         if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) {
                             newMatches[i].graph->add(input.first, false);
                             if (!anchor.empty()) {
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index bc831c7b9..c9d7faa81 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -65,6 +65,30 @@ TEST_CASE("[core/graph] Matching") {
     SECTION("Conv->(ReLU->Pad->Conv)*") {
         const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*");
 
+        checkMatches(results, {
+            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "relu1", "relu2"}},
+            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "relu2"}},
+            {"conv3_conv", {"conv3_conv"}},
+            {"conv4_conv", {"conv4_conv"}},
+            {"conv5_conv", {"conv5_conv"}}
+        });
+    }
+
+    SECTION("Conv->(ReLU~>Pad->Conv)*") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*");
+
+        checkMatches(results, {
+            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
+            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
+            {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}},
+            {"conv4_conv", {"conv4_conv"}},
+            {"conv5_conv", {"conv5_conv"}}
+        });
+    }
+
+    SECTION("Conv~>(ReLU~>Pad->Conv)*") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU~>Pad->Conv)*");
+
         checkMatches(results, {
             {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
             {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
@@ -77,6 +101,25 @@ TEST_CASE("[core/graph] Matching") {
     SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") {
         const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer");
 
+        checkMatches(results, {
+            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
+            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
+        });
+    }
+
+    SECTION("Pad->Conv#~>ReLU;Conv#<1-Producer;Conv#<2-Producer") {
+        const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;Conv#<1-Producer;Conv#<2-Producer");
+
+        checkMatches(results, {
+            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
+            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
+            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
+        });
+    }
+
+    SECTION("Pad->Conv#~>ReLU;(Conv#<*-Producer){2}") {
+        const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;(Conv#<*-Producer){2}");
+
         checkMatches(results, {
             {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
             {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
@@ -87,6 +130,15 @@ TEST_CASE("[core/graph] Matching") {
     SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") {
         const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}");
 
+        checkMatches(results, {
+            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
+            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
+        });
+    }
+
+    SECTION("Pad->Conv#~>ReLU;(Conv#<*-.){2}") {
+        const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;(Conv#<*-.){2}");
+
         checkMatches(results, {
             {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
             {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
@@ -99,8 +151,19 @@ TEST_CASE("[core/graph] Matching") {
 
         checkMatches(results, {
             {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
-            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
-            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
+            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
+        });
+    }
+
+    SECTION("Conv#~>ReLU*;Conv#<-Pad*") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU*;Conv#<-Pad*");
+
+        checkMatches(results, {
+            {"conv1", {"conv1", "relu1"}},
+            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
+            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
+            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
+            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
         });
     }
 
@@ -112,13 +175,37 @@ TEST_CASE("[core/graph] Matching") {
             {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
             {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
             {"conv4_conv", {"conv4_conv", "conv4_pad"}},
-            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
+            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
         });
     }
 
     SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") {
         const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?");
 
+        checkMatches(results, {
+            {"conv1", {"conv1", "relu1"}},
+            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
+            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
+            {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
+            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
+        });
+    }
+
+    SECTION("Conv#~>ReLU?-*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*-.)?") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?-*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*-.)?");
+
+        checkMatches(results, {
+            {"conv1", {"conv1", "relu1"}},
+            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
+            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
+            {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
+            {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}}
+        });
+    }
+
+    SECTION("Conv#~>ReLU?~*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*~.)?") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?~*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*~.)?");
+
         checkMatches(results, {
             {"conv1", {"conv1", "relu1"}},
             {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
@@ -131,6 +218,18 @@ TEST_CASE("[core/graph] Matching") {
     SECTION("Conv#->ReLU?;Conv#<-Pad?") {
         const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?");
 
+        checkMatches(results, {
+            {"conv1", {"conv1", "relu1"}},
+            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
+            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
+            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
+            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
+        });
+    }
+
+    SECTION("Conv#~>ReLU?;Conv#<-Pad?") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?;Conv#<-Pad?");
+
         checkMatches(results, {
             {"conv1", {"conv1", "relu1"}},
             {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
@@ -158,8 +257,8 @@ TEST_CASE("[core/graph] Matching") {
         });
     }
 
-    SECTION("(Add#<*-.)+") {
-        const auto results = SinglePassGraphMatching(g1).match("(Add#<*-.)+");
+    SECTION("(Add#<*~.)+") {
+        const auto results = SinglePassGraphMatching(g1).match("(Add#<*~.)+");
 
         checkMatches(results, {
             {"add", {"add", "conv4_conv", "relu3"}},
@@ -167,21 +266,21 @@ TEST_CASE("[core/graph] Matching") {
         });
     }
 
-    SECTION("Conv-*>(ReLU&Add)") {
-        const auto results = SinglePassGraphMatching(g1).match("Conv-*>(ReLU&Add)");
+    SECTION("Conv~*>(ReLU&Add)") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv~*>(ReLU&Add)");
 
         checkMatches(results, {
             {"conv5_conv", {"add2", "conv5_conv", "relu5"}}
         });
     }
 
-    SECTION("Conv->(ReLU&Add)") {
-        const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU&Add)");
+    SECTION("Conv~>(ReLU&Add)") {
+        const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU&Add)");
         REQUIRE(results.size() == 0);
     }
 
-    SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") {
-        const auto results = SinglePassGraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)");
+    SECTION("ReLU~*>((Pad->Conv-*>Add#)&Add#)") {
+        const auto results = SinglePassGraphMatching(g1).match("ReLU~*>((Pad->Conv-*>Add#)&Add#)");
 
         checkMatches(results, {
             {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}}
@@ -223,22 +322,29 @@ TEST_CASE("[core/graph] Matching") {
     }
 
     SECTION("Conv->ReLU [perf]") {
-        const size_t nbTests = 10;
+        const size_t nbTests = 3;
         std::mt19937::result_type seed(1);
 
         for (int test = 0; test < nbTests; ++test) {
             RandomGraph randGraph;
-            randGraph.types = {"Conv", "ReLU"};
-            randGraph.typesWeights = {0.9, 0.1};
+            randGraph.types = {"Conv", "ReLU", "Dummy"};
+            randGraph.typesWeights = {0.4, 0.4, 0.2};
+            randGraph.avgIn = 1;
+            randGraph.maxIn = 1;
+            randGraph.maxOut = 1;
+            randGraph.avgOut = 1;
+            randGraph.density = 0.9;
+            randGraph.acyclic = true;
             const auto g1 = std::make_shared<GraphView>("g1");
 
             Log::setConsoleLevel(Log::Warn);
             g1->add(randGraph.gen(seed, 100));
+            g1->save("graph_single_pass");
 
             auto gm = SinglePassGraphMatching(g1);
             
             const auto start = std::chrono::system_clock::now();
-            const auto results = gm.match("Conv->ReLU");
+            const auto results = gm.match("Conv->ReLU#;ReLU#->Dummy");
             const auto end = std::chrono::system_clock::now();
             const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
 
-- 
GitLab