From 1c104764991780527cb2103899a3e13800e60932 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 15 May 2024 18:29:31 +0200 Subject: [PATCH] Added lambda support --- include/aidge/graph/Matching.hpp | 8 +++++++- src/graph/Matching.cpp | 30 +++++++++++++++++++++++++++++- unit_tests/graph/Test_Matching.cpp | 19 ++++++++++++++++++- 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp index 4b4f9a156..8a75a8539 100644 --- a/include/aidge/graph/Matching.hpp +++ b/include/aidge/graph/Matching.hpp @@ -99,8 +99,13 @@ public: */ std::vector<MatchingResult> match(const std::string& query); + inline void addLambda(const std::string& name, bool(func)(const NodePtr&)) { + mLambda[name] = func; + } + private: std::shared_ptr<GraphView> mGraph; + std::map<std::string, bool(*)(const NodePtr&)> mLambda; /** * NODE_OR_BLOCK = BLOCK | NODE @@ -140,7 +145,8 @@ private: /** * TYPE = [A-Za-z0-9_]+ * ANCHOR = [A-Za-z0-9_]+ - * NODE = (TYPE | '.') ('#' ANCHOR)? + * LAMBDA = [A-Za-z0-9_]+ + * NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')? */ bool matchNode(Context& ctx, std::vector<MatchingResult>& matches); diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp index d0397fa64..28d4fa3c7 100644 --- a/src/graph/Matching.cpp +++ b/src/graph/Matching.cpp @@ -454,10 +454,36 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& newCtx.query = newCtx.query.substr(endAnchor - newCtx.query.begin()); } + std::string lambda = ""; + if (!newCtx.query.empty() && newCtx.query[0] == '[') { + newCtx.query.erase(0, 1); + + const auto endIdentifier = std::find_if(newCtx.query.begin(), newCtx.query.end(), + [](char c) { return (!isalnum(c) && c != '_'); }); + + if (endIdentifier == newCtx.query.begin()) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + lambda = newCtx.query.substr(0, endIdentifier - newCtx.query.begin()); + newCtx.query = newCtx.query.substr(endIdentifier - newCtx.query.begin()); + + if (!newCtx.query.empty() && newCtx.query[0] == ']') { + newCtx.query.erase(0, 1); + } + else { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + } + if (newCtx.firstSequence && newCtx.firstNode) { // First node of first sequence = root node for (auto node : mGraph->getNodes()) { - if (type.empty() || node->type() == type) { + if ((type.empty() || node->type() == type) + && (lambda.empty() || mLambda.at(lambda)(node))) + { MatchingResult result; result.graph->add(node, false); if (!anchor.empty()) { @@ -499,6 +525,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& for (const auto& output : outputs) { for (const auto& node : output) { if ((type.empty() || node.first->type() == type) + && (lambda.empty() || mLambda.at(lambda)(node.first)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx)) { if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) { @@ -530,6 +557,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& for (const auto& input : inputs) { if ((type.empty() || input.first->type() == type) + && (lambda.empty() || mLambda.at(lambda)(input.first)) && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx)) { if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) { diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index 10d2da45a..c1d84f822 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -30,7 +30,7 @@ TEST_CASE("[core/graph] Matching") { ReLU("relu1"), PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}), ReLU("relu2"), - PaddedConv(8, 16, {5, 5}, "conv3", {1, 1}, {2, 2, 2, 2}), + PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}), ReLU("relu3"), PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), Add(2, "add"), @@ -174,4 +174,21 @@ TEST_CASE("[core/graph] Matching") { const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)"); REQUIRE(results.size() == 0); } + + SECTION("Pad->Conv[3x3]->ReLU") { + auto gm = GraphMatching(g1); + gm.addLambda("3x3", [](const NodePtr& node) { + const std::shared_ptr<Conv_Op<2>> op = + std::static_pointer_cast<Conv_Op<2>>(node->getOperator()); + return (op->getAttr<std::array<DimSize_t, 2>>("KernelDims") == std::array<DimSize_t, 2>({3, 3})); + }); + + const auto results = gm.match("Pad->Conv[3x3]->ReLU"); + REQUIRE(results.size() == 1); + + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); + REQUIRE(result.graph->getNodes().size() == 3); + } + } } -- GitLab