Skip to content
Snippets Groups Projects
Commit 46d53493 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Changed default behavior to single output matching and added modifier to match multiple outputs

parent 422809c4
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!138Alternative graph matching
Pipeline #46074 failed
...@@ -33,6 +33,7 @@ public: ...@@ -33,6 +33,7 @@ public:
bool firstNode = true; bool firstNode = true;
bool inSequence = false; bool inSequence = false;
bool lookForChild = true; bool lookForChild = true;
bool singleOutput = true;
IOIndex_t edgeLeftIdx = 0; IOIndex_t edgeLeftIdx = 0;
IOIndex_t edgeRightIdx = 0; IOIndex_t edgeRightIdx = 0;
size_t depth = 0; size_t depth = 0;
...@@ -95,6 +96,15 @@ public: ...@@ -95,6 +96,15 @@ public:
* Note that the anchor is required since we intend to match several * Note that the anchor is required since we intend to match several
* inputs of the same node! * inputs of the same node!
* *
* - In Aidge, a node output can be connected to multiple other nodes. In
* your query, you can allow it or not, with the "-" or "~" modifier.
* EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected
* to a ReLU node at their output #0.
* "Conv~>ReLU" will match all the Conv connected to a ReLU even
* if they are also connected to other nodes at the same output #0.
* When implementing a match & replace recipe, beware that you don't break
* branches in the middle of your matching result if you use "~"!
*
* - The matching results can be overlapping, meaning that some nodes may be * - The matching results can be overlapping, meaning that some nodes may be
* found in multiple results. Some results may be subsets of other results. * found in multiple results. Some results may be subsets of other results.
* EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2 * EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
...@@ -144,8 +154,8 @@ private: ...@@ -144,8 +154,8 @@ private:
/** /**
* IO_INDEX_ANY = '*' * IO_INDEX_ANY = '*'
* IO_INDEX = IO_INDEX_ANY | [0-9]+ * IO_INDEX = IO_INDEX_ANY | [0-9]+
* CHILD_EDGE = '-' (IO_INDEX '-')? IO_INDEX? '>' * CHILD_EDGE = ('-' | '~') (IO_INDEX '-')? IO_INDEX? '>'
* PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? '-' * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~')
* EDGE = CHILD_EDGE | PARENT_EDGE * EDGE = CHILD_EDGE | PARENT_EDGE
*/ */
bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches); bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches);
......
...@@ -326,7 +326,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin ...@@ -326,7 +326,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin
Log::debug("{}edge", std::string(2*newCtx.depth, ' ')); Log::debug("{}edge", std::string(2*newCtx.depth, ' '));
removeWhiteSpace(newCtx.query); removeWhiteSpace(newCtx.query);
if (!newCtx.query.empty() && newCtx.query[0] == '-') { if (!newCtx.query.empty() && (newCtx.query[0] == '-' || newCtx.query[0] == '~')) {
newCtx.singleOutput = (newCtx.query[0] == '-');
newCtx.query.erase(0, 1); // drop '-' newCtx.query.erase(0, 1); // drop '-'
newCtx.lookForChild = true; newCtx.lookForChild = true;
} }
...@@ -381,7 +382,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin ...@@ -381,7 +382,8 @@ bool Aidge::SinglePassGraphMatching::matchEdge(Context& ctx, std::vector<Matchin
if (newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '>') { if (newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '>') {
newCtx.query.erase(0, 1); // drop '>' newCtx.query.erase(0, 1); // drop '>'
} }
else if (!newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '-') { else if (!newCtx.lookForChild && !newCtx.query.empty() && (newCtx.query[0] == '-' || newCtx.query[0] == '~')) {
newCtx.singleOutput = (newCtx.query[0] == '-');
newCtx.query.erase(0, 1); // drop '-' newCtx.query.erase(0, 1); // drop '-'
} }
else { else {
...@@ -523,6 +525,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -523,6 +525,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
: newMatches[i].startNode->outputs(); : newMatches[i].startNode->outputs();
for (const auto& output : outputs) { for (const auto& output : outputs) {
if (newCtx.singleOutput && output.size() > 1) {
continue;
}
for (const auto& node : output) { for (const auto& node : output) {
if ((type.empty() || node.first->type() == type) if ((type.empty() || node.first->type() == type)
&& (lambda.empty() || mLambda.at(lambda)(node.first)) && (lambda.empty() || mLambda.at(lambda)(node.first))
...@@ -560,6 +566,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin ...@@ -560,6 +566,10 @@ bool Aidge::SinglePassGraphMatching::matchNode(Context& ctx, std::vector<Matchin
&& (lambda.empty() || mLambda.at(lambda)(input.first)) && (lambda.empty() || mLambda.at(lambda)(input.first))
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx))
{ {
if (newCtx.singleOutput && input.first->getChildren(input.second).size() > 1) {
continue;
}
if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) { if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) {
newMatches[i].graph->add(input.first, false); newMatches[i].graph->add(input.first, false);
if (!anchor.empty()) { if (!anchor.empty()) {
......
...@@ -65,6 +65,30 @@ TEST_CASE("[core/graph] Matching") { ...@@ -65,6 +65,30 @@ TEST_CASE("[core/graph] Matching") {
SECTION("Conv->(ReLU->Pad->Conv)*") { SECTION("Conv->(ReLU->Pad->Conv)*") {
const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*");
checkMatches(results, {
{"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "relu1", "relu2"}},
{"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "relu2"}},
{"conv3_conv", {"conv3_conv"}},
{"conv4_conv", {"conv4_conv"}},
{"conv5_conv", {"conv5_conv"}}
});
}
SECTION("Conv->(ReLU~>Pad->Conv)*") {
const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU~>Pad->Conv)*");
checkMatches(results, {
{"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
{"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
{"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}},
{"conv4_conv", {"conv4_conv"}},
{"conv5_conv", {"conv5_conv"}}
});
}
SECTION("Conv~>(ReLU~>Pad->Conv)*") {
const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU~>Pad->Conv)*");
checkMatches(results, { checkMatches(results, {
{"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}}, {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
{"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}}, {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
...@@ -77,6 +101,25 @@ TEST_CASE("[core/graph] Matching") { ...@@ -77,6 +101,25 @@ TEST_CASE("[core/graph] Matching") {
SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") { SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") {
const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer");
checkMatches(results, {
{"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
{"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
});
}
SECTION("Pad->Conv#~>ReLU;Conv#<1-Producer;Conv#<2-Producer") {
const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;Conv#<1-Producer;Conv#<2-Producer");
checkMatches(results, {
{"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
{"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
{"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
});
}
SECTION("Pad->Conv#~>ReLU;(Conv#<*-Producer){2}") {
const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;(Conv#<*-Producer){2}");
checkMatches(results, { checkMatches(results, {
{"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
{"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
...@@ -87,6 +130,15 @@ TEST_CASE("[core/graph] Matching") { ...@@ -87,6 +130,15 @@ TEST_CASE("[core/graph] Matching") {
SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") { SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") {
const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}");
checkMatches(results, {
{"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
{"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
});
}
SECTION("Pad->Conv#~>ReLU;(Conv#<*-.){2}") {
const auto results = SinglePassGraphMatching(g1).match("Pad->Conv#~>ReLU;(Conv#<*-.){2}");
checkMatches(results, { checkMatches(results, {
{"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
{"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
...@@ -99,8 +151,19 @@ TEST_CASE("[core/graph] Matching") { ...@@ -99,8 +151,19 @@ TEST_CASE("[core/graph] Matching") {
checkMatches(results, { checkMatches(results, {
{"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}}, {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
{"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}, {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
{"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}} });
}
SECTION("Conv#~>ReLU*;Conv#<-Pad*") {
const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU*;Conv#<-Pad*");
checkMatches(results, {
{"conv1", {"conv1", "relu1"}},
{"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
{"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
{"conv4_conv", {"conv4_conv", "conv4_pad"}},
{"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
}); });
} }
...@@ -112,13 +175,37 @@ TEST_CASE("[core/graph] Matching") { ...@@ -112,13 +175,37 @@ TEST_CASE("[core/graph] Matching") {
{"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
{"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}}, {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
{"conv4_conv", {"conv4_conv", "conv4_pad"}}, {"conv4_conv", {"conv4_conv", "conv4_pad"}},
{"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}} {"conv5_conv", {"conv5_conv", "conv5_pad"}}
}); });
} }
SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") {
const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?");
checkMatches(results, {
{"conv1", {"conv1", "relu1"}},
{"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
{"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
{"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
{"conv5_conv", {"conv5_conv", "conv5_pad"}}
});
}
SECTION("Conv#~>ReLU?-*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*-.)?") {
const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?-*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*-.)?");
checkMatches(results, {
{"conv1", {"conv1", "relu1"}},
{"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
{"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
{"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
{"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}}
});
}
SECTION("Conv#~>ReLU?~*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*~.)?") {
const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?~*>Add#1?~>ReLU?;Conv#<-Pad?;(Add#1<*~.)?");
checkMatches(results, { checkMatches(results, {
{"conv1", {"conv1", "relu1"}}, {"conv1", {"conv1", "relu1"}},
{"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
...@@ -131,6 +218,18 @@ TEST_CASE("[core/graph] Matching") { ...@@ -131,6 +218,18 @@ TEST_CASE("[core/graph] Matching") {
SECTION("Conv#->ReLU?;Conv#<-Pad?") { SECTION("Conv#->ReLU?;Conv#<-Pad?") {
const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); const auto results = SinglePassGraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?");
checkMatches(results, {
{"conv1", {"conv1", "relu1"}},
{"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
{"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
{"conv4_conv", {"conv4_conv", "conv4_pad"}},
{"conv5_conv", {"conv5_conv", "conv5_pad"}}
});
}
SECTION("Conv#~>ReLU?;Conv#<-Pad?") {
const auto results = SinglePassGraphMatching(g1).match("Conv#~>ReLU?;Conv#<-Pad?");
checkMatches(results, { checkMatches(results, {
{"conv1", {"conv1", "relu1"}}, {"conv1", {"conv1", "relu1"}},
{"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}}, {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
...@@ -158,8 +257,8 @@ TEST_CASE("[core/graph] Matching") { ...@@ -158,8 +257,8 @@ TEST_CASE("[core/graph] Matching") {
}); });
} }
SECTION("(Add#<*-.)+") { SECTION("(Add#<*~.)+") {
const auto results = SinglePassGraphMatching(g1).match("(Add#<*-.)+"); const auto results = SinglePassGraphMatching(g1).match("(Add#<*~.)+");
checkMatches(results, { checkMatches(results, {
{"add", {"add", "conv4_conv", "relu3"}}, {"add", {"add", "conv4_conv", "relu3"}},
...@@ -167,21 +266,21 @@ TEST_CASE("[core/graph] Matching") { ...@@ -167,21 +266,21 @@ TEST_CASE("[core/graph] Matching") {
}); });
} }
SECTION("Conv-*>(ReLU&Add)") { SECTION("Conv~*>(ReLU&Add)") {
const auto results = SinglePassGraphMatching(g1).match("Conv-*>(ReLU&Add)"); const auto results = SinglePassGraphMatching(g1).match("Conv~*>(ReLU&Add)");
checkMatches(results, { checkMatches(results, {
{"conv5_conv", {"add2", "conv5_conv", "relu5"}} {"conv5_conv", {"add2", "conv5_conv", "relu5"}}
}); });
} }
SECTION("Conv->(ReLU&Add)") { SECTION("Conv~>(ReLU&Add)") {
const auto results = SinglePassGraphMatching(g1).match("Conv->(ReLU&Add)"); const auto results = SinglePassGraphMatching(g1).match("Conv~>(ReLU&Add)");
REQUIRE(results.size() == 0); REQUIRE(results.size() == 0);
} }
SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") { SECTION("ReLU~*>((Pad->Conv-*>Add#)&Add#)") {
const auto results = SinglePassGraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)"); const auto results = SinglePassGraphMatching(g1).match("ReLU~*>((Pad->Conv-*>Add#)&Add#)");
checkMatches(results, { checkMatches(results, {
{"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}} {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}}
...@@ -223,22 +322,29 @@ TEST_CASE("[core/graph] Matching") { ...@@ -223,22 +322,29 @@ TEST_CASE("[core/graph] Matching") {
} }
SECTION("Conv->ReLU [perf]") { SECTION("Conv->ReLU [perf]") {
const size_t nbTests = 10; const size_t nbTests = 3;
std::mt19937::result_type seed(1); std::mt19937::result_type seed(1);
for (int test = 0; test < nbTests; ++test) { for (int test = 0; test < nbTests; ++test) {
RandomGraph randGraph; RandomGraph randGraph;
randGraph.types = {"Conv", "ReLU"}; randGraph.types = {"Conv", "ReLU", "Dummy"};
randGraph.typesWeights = {0.9, 0.1}; randGraph.typesWeights = {0.4, 0.4, 0.2};
randGraph.avgIn = 1;
randGraph.maxIn = 1;
randGraph.maxOut = 1;
randGraph.avgOut = 1;
randGraph.density = 0.9;
randGraph.acyclic = true;
const auto g1 = std::make_shared<GraphView>("g1"); const auto g1 = std::make_shared<GraphView>("g1");
Log::setConsoleLevel(Log::Warn); Log::setConsoleLevel(Log::Warn);
g1->add(randGraph.gen(seed, 100)); g1->add(randGraph.gen(seed, 100));
g1->save("graph_single_pass");
auto gm = SinglePassGraphMatching(g1); auto gm = SinglePassGraphMatching(g1);
const auto start = std::chrono::system_clock::now(); const auto start = std::chrono::system_clock::now();
const auto results = gm.match("Conv->ReLU"); const auto results = gm.match("Conv->ReLU#;ReLU#->Dummy");
const auto end = std::chrono::system_clock::now(); const auto end = std::chrono::system_clock::now();
const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start); const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment