From 79ed9f8384597c5cb9933560b0655ee720a68db5 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 16 May 2024 10:47:31 +0200
Subject: [PATCH] Improved tests

---
 include/aidge/graph/Matching.hpp   |   2 +-
 unit_tests/graph/Test_Matching.cpp | 142 ++++++++++++++++-------------
 2 files changed, 79 insertions(+), 65 deletions(-)

diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp
index 8a75a8539..983967fb9 100644
--- a/include/aidge/graph/Matching.hpp
+++ b/include/aidge/graph/Matching.hpp
@@ -99,7 +99,7 @@ public:
     */
     std::vector<MatchingResult> match(const std::string& query);
 
-    inline void addLambda(const std::string& name, bool(func)(const NodePtr&)) {
+    inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) {
         mLambda[name] = func;
     }
 
diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp
index c1d84f822..6c5abcc47 100644
--- a/unit_tests/graph/Test_Matching.cpp
+++ b/unit_tests/graph/Test_Matching.cpp
@@ -23,6 +23,19 @@
 
 using namespace Aidge;
 
+void checkMatches(const std::vector<GraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) {
+    REQUIRE(results.size() == expected.size());
+
+    for (const auto& result : results) {
+        const auto found = nodePtrTo(result.graph->getNodes(), nodePtrToName);
+        fmt::print("Found: {}\n", found);
+
+        const auto rootNode = result.graph->rootNode()->name();
+        const auto expectedSet = expected.at(rootNode);
+        REQUIRE(found == expectedSet);
+    }
+}
+
 TEST_CASE("[core/graph] Matching") {
     auto g1 = Sequential({
         Producer({16, 3, 512, 512}, "dataProvider"),
@@ -49,110 +62,115 @@ TEST_CASE("[core/graph] Matching") {
 
     SECTION("Conv->(ReLU->Pad->Conv)*") {
         const auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*");
-        REQUIRE(results.size() == 5);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-        }
+        checkMatches(results, {
+            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
+            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
+            {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}},
+            {"conv4_conv", {"conv4_conv"}},
+            {"conv5_conv", {"conv5_conv"}}
+        });
     }
 
     SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") {
         const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer");
-        REQUIRE(results.size() == 3);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 5);
-        }
+        checkMatches(results, {
+            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
+            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
+            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
+        });
     }
 
     SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") {
         const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}");
-        REQUIRE(results.size() == 3);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 5);
-        }
+        checkMatches(results, {
+            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
+            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
+            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
+        });
     }
 
     SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") {
         const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}");
-        REQUIRE(results.size() == 3);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 5);
-        }
+        checkMatches(results, {
+            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
+            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
+            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
+        });
     }
 
     SECTION("Conv#->ReLU*;Conv#<-Pad*") {
         const auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*");
-        REQUIRE(results.size() == 5);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3));
-        }
+        checkMatches(results, {
+            {"conv1", {"conv1", "relu1"}},
+            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
+            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
+            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
+            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
+        });
     }
 
     SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") {
         const auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?");
-        REQUIRE(results.size() == 5);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-        }
+        checkMatches(results, {
+            {"conv1", {"conv1", "relu1"}},
+            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
+            {"conv3_conv", {"add", "conv3_conv", "conv3_pad", "conv4_conv", "relu3"}},
+            {"conv4_conv", {"add", "conv4_conv", "conv4_pad", "relu3"}},
+            {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}}
+        });
     }
 
     SECTION("Conv#->ReLU?;Conv#<-Pad?") {
         const auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?");
-        REQUIRE(results.size() == 5);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3));
-        }
+        checkMatches(results, {
+            {"conv1", {"conv1", "relu1"}},
+            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
+            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
+            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
+            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
+        });
     }
 
     SECTION("(Conv|ReLU)->Add") {
         const auto results = GraphMatching(g1).match("(Conv|ReLU)->Add");
-        REQUIRE(results.size() == 2);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 2);
-        }
+        checkMatches(results, {
+            {"conv4_conv", {"add", "conv4_conv"}},
+            {"relu5", {"add2", "relu5"}}
+        });
     }
 
     SECTION("Add<*-.") {
         const auto results = GraphMatching(g1).match("Add<*-.");
-        REQUIRE(results.size() == 2);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 2);
-        }
+        checkMatches(results, {
+            {"add", {"add", "conv4_conv"}},
+            {"add2", {"add2", "relu5"}}
+        });
     }
 
     SECTION("(Add#<*-.)+") {
         const auto results = GraphMatching(g1).match("(Add#<*-.)+");
-        REQUIRE(results.size() == 2);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 3);
-        }
+        checkMatches(results, {
+            {"add", {"add", "conv4_conv", "relu3"}},
+            {"add2", {"add2", "conv5_conv", "relu5"}}
+        });
     }
 
     SECTION("Conv-*>(ReLU&Add)") {
         const auto results = GraphMatching(g1).match("Conv-*>(ReLU&Add)");
-        REQUIRE(results.size() == 1);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 3);
-        }
+        checkMatches(results, {
+            {"conv5_conv", {"add2", "conv5_conv", "relu5"}}
+        });
     }
 
     SECTION("Conv->(ReLU&Add)") {
@@ -162,12 +180,10 @@ TEST_CASE("[core/graph] Matching") {
 
     SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") {
         const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)");
-        REQUIRE(results.size() == 1);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 4);
-        }
+        checkMatches(results, {
+            {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}}
+        });
     }
 
     SECTION("ReLU-*>((Pad->Conv-*>Add)&Add)") {
@@ -177,18 +193,16 @@ TEST_CASE("[core/graph] Matching") {
 
     SECTION("Pad->Conv[3x3]->ReLU") {
         auto gm = GraphMatching(g1);
-        gm.addLambda("3x3", [](const NodePtr& node) {
+        gm.addNodeLambda("3x3", [](const NodePtr& node) {
             const std::shared_ptr<Conv_Op<2>> op =
                 std::static_pointer_cast<Conv_Op<2>>(node->getOperator());
             return (op->getAttr<std::array<DimSize_t, 2>>("KernelDims") == std::array<DimSize_t, 2>({3, 3}));
         });
 
         const auto results = gm.match("Pad->Conv[3x3]->ReLU");
-        REQUIRE(results.size() == 1);
 
-        for (const auto& result : results) {
-            fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
-            REQUIRE(result.graph->getNodes().size() == 3);
-        }
+        checkMatches(results, {
+            {"conv3_pad", {"conv3_conv", "conv3_pad", "relu3"}}
+        });
     }
 }
-- 
GitLab