diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index 8a75a8539e931dcec8db14bb4fecf08b12a7a82a..983967fb9d02c27f5769585160cc331ee943cd36 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -99,7 +99,7 @@ public: */ std::vector<MatchingResult> match(const std::string& query); - inline void addLambda(const std::string& name, bool(func)(const NodePtr&)) { + inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { mLambda[name] = func; } diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index c1d84f8220dcd002a33931deb21f51de6bafd7ce..6c5abcc4725d7f27292970ab8b2746227775751b 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -23,6 +23,19 @@ using namespace Aidge; +void checkMatches(const std::vector<GraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) { + REQUIRE(results.size() == expected.size()); + + for (const auto& result : results) { + const auto found = nodePtrTo(result.graph->getNodes(), nodePtrToName); + fmt::print("Found: {}\n", found); + + const auto rootNode = result.graph->rootNode()->name(); + const auto expectedSet = expected.at(rootNode); + REQUIRE(found == expectedSet); + } +} + TEST_CASE("[core/graph] Matching") { auto g1 = Sequential({ Producer({16, 3, 512, 512}, "dataProvider"), @@ -49,110 +62,115 @@ TEST_CASE("[core/graph] Matching") { SECTION("Conv->(ReLU->Pad->Conv)*") { const auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - } + checkMatches(results, { + {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}}, + {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}}, + {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv"}}, + {"conv5_conv", {"conv5_conv"}} + }); } SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") { const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); - REQUIRE(results.size() == 3); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 5); - } + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, + {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + }); } SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") { const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); - REQUIRE(results.size() == 3); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 5); - } + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, + {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + }); } SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") { const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}"); - REQUIRE(results.size() == 3); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 5); - } + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, + {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + }); } SECTION("Conv#->ReLU*;Conv#<-Pad*") { const auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); - } + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}} + }); } SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { const auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - } + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"add", "conv3_conv", "conv3_pad", "conv4_conv", "relu3"}}, + {"conv4_conv", {"add", "conv4_conv", "conv4_pad", "relu3"}}, + {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}} + }); } SECTION("Conv#->ReLU?;Conv#<-Pad?") { const auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); - REQUIRE(results.size() == 5); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); - } + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}} + }); } SECTION("(Conv|ReLU)->Add") { const auto results = GraphMatching(g1).match("(Conv|ReLU)->Add"); - REQUIRE(results.size() == 2); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 2); - } + checkMatches(results, { + {"conv4_conv", {"add", "conv4_conv"}}, + {"relu5", {"add2", "relu5"}} + }); } SECTION("Add<*-.") { const auto results = GraphMatching(g1).match("Add<*-."); - REQUIRE(results.size() == 2); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 2); - } + checkMatches(results, { + {"add", {"add", "conv4_conv"}}, + {"add2", {"add2", "relu5"}} + }); } SECTION("(Add#<*-.)+") { const auto results = GraphMatching(g1).match("(Add#<*-.)+"); - REQUIRE(results.size() == 2); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 3); - } + checkMatches(results, { + {"add", {"add", "conv4_conv", "relu3"}}, + {"add2", {"add2", "conv5_conv", "relu5"}} + }); } SECTION("Conv-*>(ReLU&Add)") { const auto results = GraphMatching(g1).match("Conv-*>(ReLU&Add)"); - REQUIRE(results.size() == 1); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 3); - } + checkMatches(results, { + {"conv5_conv", {"add2", "conv5_conv", "relu5"}} + }); } SECTION("Conv->(ReLU&Add)") { @@ -162,12 +180,10 @@ TEST_CASE("[core/graph] Matching") { SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") { const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)"); - REQUIRE(results.size() == 1); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 4); - } + checkMatches(results, { + {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}} + }); } SECTION("ReLU-*>((Pad->Conv-*>Add)&Add)") { @@ -177,18 +193,16 @@ TEST_CASE("[core/graph] Matching") { SECTION("Pad->Conv[3x3]->ReLU") { auto gm = GraphMatching(g1); - gm.addLambda("3x3", [](const NodePtr& node) { + gm.addNodeLambda("3x3", [](const NodePtr& node) { const std::shared_ptr<Conv_Op<2>> op = std::static_pointer_cast<Conv_Op<2>>(node->getOperator()); return (op->getAttr<std::array<DimSize_t, 2>>("KernelDims") == std::array<DimSize_t, 2>({3, 3})); }); const auto results = gm.match("Pad->Conv[3x3]->ReLU"); - REQUIRE(results.size() == 1); - for (const auto& result : results) { - fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); - REQUIRE(result.graph->getNodes().size() == 3); - } + checkMatches(results, { + {"conv3_pad", {"conv3_conv", "conv3_pad", "relu3"}} + }); } }