Skip to content
Snippets Groups Projects
Commit 1c104764 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added lambda support

parent 3268e941
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!138Alternative graph matching
...@@ -99,8 +99,13 @@ public: ...@@ -99,8 +99,13 @@ public:
*/ */
std::vector<MatchingResult> match(const std::string& query); std::vector<MatchingResult> match(const std::string& query);
inline void addLambda(const std::string& name, bool(func)(const NodePtr&)) {
mLambda[name] = func;
}
private: private:
std::shared_ptr<GraphView> mGraph; std::shared_ptr<GraphView> mGraph;
std::map<std::string, bool(*)(const NodePtr&)> mLambda;
/** /**
* NODE_OR_BLOCK = BLOCK | NODE * NODE_OR_BLOCK = BLOCK | NODE
...@@ -140,7 +145,8 @@ private: ...@@ -140,7 +145,8 @@ private:
/** /**
* TYPE = [A-Za-z0-9_]+ * TYPE = [A-Za-z0-9_]+
* ANCHOR = [A-Za-z0-9_]+ * ANCHOR = [A-Za-z0-9_]+
* NODE = (TYPE | '.') ('#' ANCHOR)? * LAMBDA = [A-Za-z0-9_]+
* NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?
*/ */
bool matchNode(Context& ctx, std::vector<MatchingResult>& matches); bool matchNode(Context& ctx, std::vector<MatchingResult>& matches);
......
...@@ -454,10 +454,36 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& ...@@ -454,10 +454,36 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
newCtx.query = newCtx.query.substr(endAnchor - newCtx.query.begin()); newCtx.query = newCtx.query.substr(endAnchor - newCtx.query.begin());
} }
std::string lambda = "";
if (!newCtx.query.empty() && newCtx.query[0] == '[') {
newCtx.query.erase(0, 1);
const auto endIdentifier = std::find_if(newCtx.query.begin(), newCtx.query.end(),
[](char c) { return (!isalnum(c) && c != '_'); });
if (endIdentifier == newCtx.query.begin()) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
lambda = newCtx.query.substr(0, endIdentifier - newCtx.query.begin());
newCtx.query = newCtx.query.substr(endIdentifier - newCtx.query.begin());
if (!newCtx.query.empty() && newCtx.query[0] == ']') {
newCtx.query.erase(0, 1);
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
}
if (newCtx.firstSequence && newCtx.firstNode) { if (newCtx.firstSequence && newCtx.firstNode) {
// First node of first sequence = root node // First node of first sequence = root node
for (auto node : mGraph->getNodes()) { for (auto node : mGraph->getNodes()) {
if (type.empty() || node->type() == type) { if ((type.empty() || node->type() == type)
&& (lambda.empty() || mLambda.at(lambda)(node)))
{
MatchingResult result; MatchingResult result;
result.graph->add(node, false); result.graph->add(node, false);
if (!anchor.empty()) { if (!anchor.empty()) {
...@@ -499,6 +525,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& ...@@ -499,6 +525,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
for (const auto& output : outputs) { for (const auto& output : outputs) {
for (const auto& node : output) { for (const auto& node : output) {
if ((type.empty() || node.first->type() == type) if ((type.empty() || node.first->type() == type)
&& (lambda.empty() || mLambda.at(lambda)(node.first))
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx))
{ {
if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) { if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) {
...@@ -530,6 +557,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& ...@@ -530,6 +557,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
for (const auto& input : inputs) { for (const auto& input : inputs) {
if ((type.empty() || input.first->type() == type) if ((type.empty() || input.first->type() == type)
&& (lambda.empty() || mLambda.at(lambda)(input.first))
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx))
{ {
if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) { if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) {
......
...@@ -30,7 +30,7 @@ TEST_CASE("[core/graph] Matching") { ...@@ -30,7 +30,7 @@ TEST_CASE("[core/graph] Matching") {
ReLU("relu1"), ReLU("relu1"),
PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}), PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
ReLU("relu2"), ReLU("relu2"),
PaddedConv(8, 16, {5, 5}, "conv3", {1, 1}, {2, 2, 2, 2}), PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
ReLU("relu3"), ReLU("relu3"),
PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
Add(2, "add"), Add(2, "add"),
...@@ -174,4 +174,21 @@ TEST_CASE("[core/graph] Matching") { ...@@ -174,4 +174,21 @@ TEST_CASE("[core/graph] Matching") {
const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)"); const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)");
REQUIRE(results.size() == 0); REQUIRE(results.size() == 0);
} }
SECTION("Pad->Conv[3x3]->ReLU") {
auto gm = GraphMatching(g1);
gm.addLambda("3x3", [](const NodePtr& node) {
const std::shared_ptr<Conv_Op<2>> op =
std::static_pointer_cast<Conv_Op<2>>(node->getOperator());
return (op->getAttr<std::array<DimSize_t, 2>>("KernelDims") == std::array<DimSize_t, 2>({3, 3}));
});
const auto results = gm.match("Pad->Conv[3x3]->ReLU");
REQUIRE(results.size() == 1);
for (const auto& result : results) {
fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
REQUIRE(result.graph->getNodes().size() == 3);
}
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment