From 1c104764991780527cb2103899a3e13800e60932 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 15 May 2024 18:29:31 +0200
Subject: [PATCH] Added lambda support

---
 include/aidge/graph/Matching.hpp   |  8 +++++++-
 src/graph/Matching.cpp             | 30 +++++++++++++++++++++++++++++-
 unit_tests/graph/Test_Matching.cpp | 19 ++++++++++++++++++-
 3 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp
index 4b4f9a156..8a75a8539 100644
--- a/include/aidge/graph/Matching.hpp
+++ b/include/aidge/graph/Matching.hpp
@@ -99,8 +99,13 @@ public:
     */
     std::vector<MatchingResult> match(const std::string& query);
 
+    inline void addLambda(const std::string& name, bool(func)(const NodePtr&)) {
+        mLambda[name] = func;
+    }
+
 private:
     std::shared_ptr<GraphView> mGraph;
+    std::map<std::string, bool(*)(const NodePtr&)> mLambda;
 
     /**
      * NODE_OR_BLOCK = BLOCK | NODE
@@ -140,7 +145,8 @@ private:
     /**
      * TYPE = [A-Za-z0-9_]+
      * ANCHOR = [A-Za-z0-9_]+
-     * NODE = (TYPE | '.') ('#' ANCHOR)?
+     * LAMBDA = [A-Za-z0-9_]+
+     * NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?
     */
     bool matchNode(Context& ctx, std::vector<MatchingResult>& matches);
 
diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp
index d0397fa64..28d4fa3c7 100644
--- a/src/graph/Matching.cpp
+++ b/src/graph/Matching.cpp
@@ -454,10 +454,36 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
         newCtx.query = newCtx.query.substr(endAnchor - newCtx.query.begin());
     }
 
+    std::string lambda = "";
+    if (!newCtx.query.empty() && newCtx.query[0] == '[') {
+        newCtx.query.erase(0, 1);
+
+        const auto endIdentifier = std::find_if(newCtx.query.begin(), newCtx.query.end(), 
+            [](char c) { return (!isalnum(c) && c != '_'); });
+
+        if (endIdentifier == newCtx.query.begin()) {
+            Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
+            return false;
+        }
+
+        lambda = newCtx.query.substr(0, endIdentifier - newCtx.query.begin());
+        newCtx.query = newCtx.query.substr(endIdentifier - newCtx.query.begin());
+
+        if (!newCtx.query.empty() && newCtx.query[0] == ']') {
+            newCtx.query.erase(0, 1);
+        }
+        else {
+            Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
+            return false;
+        }
+    }
+
     if (newCtx.firstSequence && newCtx.firstNode) {
         // First node of first sequence = root node
         for (auto node : mGraph->getNodes()) {
-            if (type.empty() || node->type() == type) {
+            if ((type.empty() || node->type() == type)
+                && (lambda.empty() || mLambda.at(lambda)(node)))
+            {
                 MatchingResult result;
                 result.graph->add(node, false);
                 if (!anchor.empty()) {
@@ -499,6 +525,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
                 for (const auto& output : outputs) {
                     for (const auto& node : output) {
                         if ((type.empty() || node.first->type() == type)
+                            && (lambda.empty() || mLambda.at(lambda)(node.first))
                             && (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx))
                         {
                             if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) {
@@ -530,6 +557,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
 
                 for (const auto& input : inputs) {
                     if ((type.empty() || input.first->type() == type)
+                        && (lambda.empty() || mLambda.at(lambda)(input.first))
                         && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx))
                     {
                         if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) {
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index 10d2da45a..c1d84f822 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -30,7 +30,7 @@ TEST_CASE("[core/graph] Matching") {
         ReLU("relu1"),
         PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
         ReLU("relu2"),
-        PaddedConv(8, 16, {5, 5}, "conv3", {1, 1}, {2, 2, 2, 2}),
+        PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
         ReLU("relu3"),
         PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
         Add(2, "add"),
@@ -174,4 +174,21 @@ TEST_CASE("[core/graph] Matching") {
         const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)");
         REQUIRE(results.size() == 0);
     }
+
+    SECTION("Pad->Conv[3x3]->ReLU") {
+        auto gm = GraphMatching(g1);
+        gm.addLambda("3x3", [](const NodePtr& node) {
+            const std::shared_ptr<Conv_Op<2>> op =
+                std::static_pointer_cast<Conv_Op<2>>(node->getOperator());
+            return (op->getAttr<std::array<DimSize_t, 2>>("KernelDims") == std::array<DimSize_t, 2>({3, 3}));
+        });
+
+        const auto results = gm.match("Pad->Conv[3x3]->ReLU");
+        REQUIRE(results.size() == 1);
+
+        for (const auto& result : results) {
+            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
+            REQUIRE(result.graph->getNodes().size() == 3);
+        }
+    }
 }
-- 
GitLab