diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index 42bb0b6efecff49a6c70e8cc3068699d35c453a0..004011481f6f72e5176c6a00634d34e9569c65ad 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -33,6 +33,7 @@ public: bool firstNode = true; bool inSequence = false; bool lookForChild = true; + bool singleOutput = true; IOIndex_t edgeLeftIdx = 0; IOIndex_t edgeRightIdx = 0; size_t depth = 0; @@ -95,6 +96,15 @@ public: * Note that the anchor is required since we intend to match several * inputs of the same node! * + * - In Aidge, a node output can be connected to multiple other nodes. In + * your query, you can allow it or not, with the "-" or "~" modifier. + * EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected + * to a ReLU node at their output #0. + * "Conv~>ReLU" will match all the Conv connected to a ReLU even + * if they are also connected to other nodes at the same output #0. + * When implementing a match & replace recipe, beware that you don't break + * branches in the middle of your matching result if you use "~"! + * * - The matching results can be overlapping, meaning that some nodes may be * found in multiple results. Some results may be subsets of other results. * EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2 @@ -144,8 +154,8 @@ private: /** * IO_INDEX_ANY = '*' * IO_INDEX = IO_INDEX_ANY | [0-9]+ - * CHILD_EDGE = '-' (IO_INDEX '-')? IO_INDEX? '>' - * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? '-' + * CHILD_EDGE = ('-' | '~') (IO_INDEX '-')? IO_INDEX? '>' + * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~') * EDGE = CHILD_EDGE | PARENT_EDGE */ bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches); diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index 3b1c4806a397eb164810e959e5edc211fbaa57a5..aae3175258760375782363be620add84e2abcee3 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -326,7 +326,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin Log::debug("{}edge", std::string(2*newCtx.depth, ' ')); removeWhiteSpace(newCtx.query); - if (!newCtx.query.empty() && newCtx.query[0] == '-') { + if (!newCtx.query.empty() && (newCtx.query[0] == '-' || newCtx.query[0] == '~')) { + newCtx.singleOutput = (newCtx.query[0] == '-'); newCtx.query.erase(0, 1); // drop '-' newCtx.lookForChild = true; } @@ -381,7 +382,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin if (newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '>') { newCtx.query.erase(0, 1); // drop '>' } - else if (!newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '-') { + else if (!newCtx.lookForChild && !newCtx.query.empty() && (newCtx.query[0] == '-' || newCtx.query[0] == '~')) { + newCtx.singleOutput = (newCtx.query[0] == '-'); newCtx.query.erase(0, 1); // drop '-' } else { @@ -523,6 +525,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin : newMatches[i].startNode->outputs(); for (const auto& output : outputs) { + if (newCtx.singleOutput && output.size() > 1) { + continue; + } + for (const auto& node : output) { if ((type.empty() || node.first->type() == type) && (lambda.empty() || mLambda.at(lambda)(node.first)) @@ -560,6 +566,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin && (lambda.empty() || mLambda.at(lambda)(input.first)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx)) { + if (newCtx.singleOutput && input.first->getChildren(input.second).size() > 1) { + continue; + } + if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) { newMatches[i].graph->add(input.first, false); if (!anchor.empty()) { diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index bc831c7b9238455f96800700292af90b58ac0fa1..c9d7faa8159949d0517331591b5f26200269dc11 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -65,6 +65,30 @@ TEST_CASE("[core/graph] Matching") { SECTION("Conv->(ReLU->Pad->Conv)*") { const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); + checkMatches(results, { + {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "relu1", "relu2"}}, + {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv"}}, + {"conv4_conv", {"conv4_conv"}}, + {"conv5_conv", {"conv5_conv"}} + }); + } + + SECTION("Conv->(ReLU~>Pad->Conv)*") { + const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*"); + + checkMatches(results, { + {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}}, + {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}}, + {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv"}}, + {"conv5_conv", {"conv5_conv"}} + }); + } + + SECTION("Conv~>(ReLU~>Pad->Conv)*") { + const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU~>Pad->Conv)*"); + checkMatches(results, { {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}}, {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}}, @@ -77,6 +101,25 @@ TEST_CASE("[core/graph] Matching") { SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") { const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}} + }); + } + + SECTION("Pad->Conv#~>ReLU;Conv#<1-Producer;Conv#<2-Producer") { + const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;Conv#<1-Producer;Conv#<2-Producer"); + + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, + {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + }); + } + + SECTION("Pad->Conv#~>ReLU;(Conv#<*-Producer){2}") { + const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;(Conv#<*-Producer){2}"); + checkMatches(results, { {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, @@ -87,6 +130,15 @@ TEST_CASE("[core/graph] Matching") { SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") { const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); + checkMatches(results, { + {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}} + }); + } + + SECTION("Pad->Conv#~>ReLU;(Conv#<*-.){2}") { + const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;(Conv#<*-.){2}"); + checkMatches(results, { {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, @@ -99,8 +151,19 @@ TEST_CASE("[core/graph] Matching") { checkMatches(results, { {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, - {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, - {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} + {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}} + }); + } + + SECTION("Conv#~>ReLU*;Conv#<-Pad*") { + const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU*;Conv#<-Pad*"); + + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}} }); } @@ -112,13 +175,37 @@ TEST_CASE("[core/graph] Matching") { {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, {"conv4_conv", {"conv4_conv", "conv4_pad"}}, - {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}} + {"conv5_conv", {"conv5_conv", "conv5_pad"}} }); } SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"conv5_conv", "conv5_pad"}} + }); + } + + SECTION("Conv#~>ReLU?-*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { + const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?-*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); + + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}} + }); + } + + SECTION("Conv#~>ReLU?~*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*~.)?") { + const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?~*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*~.)?"); + checkMatches(results, { {"conv1", {"conv1", "relu1"}}, {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, @@ -131,6 +218,18 @@ TEST_CASE("[core/graph] Matching") { SECTION("Conv#->ReLU?;Conv#<-Pad?") { const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); + checkMatches(results, { + {"conv1", {"conv1", "relu1"}}, + {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, + {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, + {"conv4_conv", {"conv4_conv", "conv4_pad"}}, + {"conv5_conv", {"conv5_conv", "conv5_pad"}} + }); + } + + SECTION("Conv#~>ReLU?;Conv#<-Pad?") { + const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?;Conv#<-Pad?"); + checkMatches(results, { {"conv1", {"conv1", "relu1"}}, {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, @@ -158,8 +257,8 @@ TEST_CASE("[core/graph] Matching") { }); } - SECTION("(Add#<*-.)+") { - const auto results = SinglePassGraphMatching(g1).match("(Add#<*-.)+"); + SECTION("(Add#<*~.)+") { + const auto results = SinglePassGraphMatching(g1).match("(Add#<*~.)+"); checkMatches(results, { {"add", {"add", "conv4_conv", "relu3"}}, @@ -167,21 +266,21 @@ TEST_CASE("[core/graph] Matching") { }); } - SECTION("Conv-*>(ReLU&Add)") { - const auto results = SinglePassGraphMatching(g1).match("Conv-*>(ReLU&Add)"); + SECTION("Conv~*>(ReLU&Add)") { + const auto results = SinglePassGraphMatching(g1).match("Conv~*>(ReLU&Add)"); checkMatches(results, { {"conv5_conv", {"add2", "conv5_conv", "relu5"}} }); } - SECTION("Conv->(ReLU&Add)") { - const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU&Add)"); + SECTION("Conv~>(ReLU&Add)") { + const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU&Add)"); REQUIRE(results.size() == 0); } - SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") { - const auto results = SinglePassGraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)"); + SECTION("ReLU~*>((Pad->Conv-*>Add#)&Add#)") { + const auto results = SinglePassGraphMatching(g1).match("ReLU~*>((Pad->Conv-*>Add#)&Add#)"); checkMatches(results, { {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}} @@ -223,22 +322,29 @@ TEST_CASE("[core/graph] Matching") { } SECTION("Conv->ReLU [perf]") { - const size_t nbTests = 10; + const size_t nbTests = 3; std::mt19937::result_type seed(1); for (int test = 0; test < nbTests; ++test) { RandomGraph randGraph; - randGraph.types = {"Conv", "ReLU"}; - randGraph.typesWeights = {0.9, 0.1}; + randGraph.types = {"Conv", "ReLU", "Dummy"}; + randGraph.typesWeights = {0.4, 0.4, 0.2}; + randGraph.avgIn = 1; + randGraph.maxIn = 1; + randGraph.maxOut = 1; + randGraph.avgOut = 1; + randGraph.density = 0.9; + randGraph.acyclic = true; const auto g1 = std::make_shared<GraphView>("g1"); Log::setConsoleLevel(Log::Warn); g1->add(randGraph.gen(seed, 100)); + g1->save("graph_single_pass"); auto gm = SinglePassGraphMatching(g1); const auto start = std::chrono::system_clock::now(); - const auto results = gm.match("Conv->ReLU"); + const auto results = gm.match("Conv->ReLU#;ReLU#->Dummy"); const auto end = std::chrono::system_clock::now(); const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);