From 79ed9f8384597c5cb9933560b0655ee720a68db5 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 16 May 2024 10:47:31 +0200 Subject: [PATCH] Improved tests --- include/aidge/graph/Matching.hpp | 2 +- unit_tests/graph/Test_Matching.cpp | 142 ++++++++++++++++------------- 2 files changed, 79 insertions(+), 65 deletions(-) diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index 8a75a8539..983967fb9 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -99,7 +99,7 @@ public: */ std::vector<MatchingResult> match(const std::string& query); - inline void addLambda(const std::string& name, bool(func)(const NodePtr&)) { + inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { mLambda[name] = func; } diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index c1d84f822..6c5abcc47 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -23,6 +23,19 @@ using namespace Aidge; +void checkMatches(const std::vector<GraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) { + REQUIRE(results.size() == expected.size()); + + for (const auto& result : results) { + const auto found = nodePtrTo(result.graph->getNodes(), nodePtrToName); + fmt::print("Found: {}\n", found); + + const auto rootNode = result.graph->rootNode()->name(); + const auto expectedSet = expected.at(rootNode); + REQUIRE(found == expectedSet); + } +} + TEST_CASE("[core/graph] Matching") { auto g1 = Sequential({ Producer({16, 3, 512, 512}, "dataProvider"), @@ -49,110 +62,115 @@ TEST_CASE("[core/graph] Matching") { SECTION("Conv->(ReLU->Pad->Conv)*") { const auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - } + checkMatches(results, { + {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}}, + {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}}, + {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv"}}, + {"conv5_conv", {"conv5_conv"}} + }); } SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") { const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); - REQUIRE(results.size() == 3); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 5); - } + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, + {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + }); } SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") { const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); - REQUIRE(results.size() == 3); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 5); - } + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, + {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + }); } SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") { const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}"); - REQUIRE(results.size() == 3); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 5); - } + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, + {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + }); } SECTION("Conv#->ReLU*;Conv#<-Pad*") { const auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); - } + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}} + }); } SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { const auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - } + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"add", "conv3_conv", "conv3_pad", "conv4_conv", "relu3"}}, + {"conv4_conv", {"add", "conv4_conv", "conv4_pad", "relu3"}}, + {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}} + }); } SECTION("Conv#->ReLU?;Conv#<-Pad?") { const auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); - } + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}} + }); } SECTION("(Conv|ReLU)->Add") { const auto results = GraphMatching(g1).match("(Conv|ReLU)->Add"); - REQUIRE(results.size() == 2); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 2); - } + checkMatches(results, { + {"conv4_conv", {"add", "conv4_conv"}}, + {"relu5", {"add2", "relu5"}} + }); } SECTION("Add<*-.") { const auto results = GraphMatching(g1).match("Add<*-."); - REQUIRE(results.size() == 2); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 2); - } + checkMatches(results, { + {"add", {"add", "conv4_conv"}}, + {"add2", {"add2", "relu5"}} + }); } SECTION("(Add#<*-.)+") { const auto results = GraphMatching(g1).match("(Add#<*-.)+"); - REQUIRE(results.size() == 2); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 3); - } + checkMatches(results, { + {"add", {"add", "conv4_conv", "relu3"}}, + {"add2", {"add2", "conv5_conv", "relu5"}} + }); } SECTION("Conv-*>(ReLU&Add)") { const auto results = GraphMatching(g1).match("Conv-*>(ReLU&Add)"); - REQUIRE(results.size() == 1); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 3); - } + checkMatches(results, { + {"conv5_conv", {"add2", "conv5_conv", "relu5"}} + }); } SECTION("Conv->(ReLU&Add)") { @@ -162,12 +180,10 @@ TEST_CASE("[core/graph] Matching") { SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") { const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)"); - REQUIRE(results.size() == 1); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 4); - } + checkMatches(results, { + {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}} + }); } SECTION("ReLU-*>((Pad->Conv-*>Add)&Add)") { @@ -177,18 +193,16 @@ TEST_CASE("[core/graph] Matching") { SECTION("Pad->Conv[3x3]->ReLU") { auto gm = GraphMatching(g1); - gm.addLambda("3x3", [](const NodePtr& node) { + gm.addNodeLambda("3x3", [](const NodePtr& node) { const std::shared_ptr<Conv_Op<2>> op = std::static_pointer_cast<Conv_Op<2>>(node->getOperator()); return (op->getAttr<std::array<DimSize_t, 2>>("KernelDims") == std::array<DimSize_t, 2>({3, 3})); }); const auto results = gm.match("Pad->Conv[3x3]->ReLU"); - REQUIRE(results.size() == 1); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 3); - } + checkMatches(results, { + {"conv3_pad", {"conv3_conv", "conv3_pad", "relu3"}} + }); } } -- GitLab