Skip to content
Snippets Groups Projects
Commit 1c104764 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added lambda support

parent 3268e941
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!138Alternative graph matching
......@@ -99,8 +99,13 @@ public:
*/
std::vector<MatchingResult> match(const std::string& query);
inline void addLambda(const std::string& name, bool(func)(const NodePtr&)) {
mLambda[name] = func;
}
private:
std::shared_ptr<GraphView> mGraph;
std::map<std::string, bool(*)(const NodePtr&)> mLambda;
/**
* NODE_OR_BLOCK = BLOCK | NODE
......@@ -140,7 +145,8 @@ private:
/**
* TYPE = [A-Za-z0-9_]+
* ANCHOR = [A-Za-z0-9_]+
* NODE = (TYPE | '.') ('#' ANCHOR)?
* LAMBDA = [A-Za-z0-9_]+
* NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?
*/
bool matchNode(Context& ctx, std::vector<MatchingResult>& matches);
......
......@@ -454,10 +454,36 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
newCtx.query = newCtx.query.substr(endAnchor - newCtx.query.begin());
}
std::string lambda = "";
if (!newCtx.query.empty() && newCtx.query[0] == '[') {
newCtx.query.erase(0, 1);
const auto endIdentifier = std::find_if(newCtx.query.begin(), newCtx.query.end(),
[](char c) { return (!isalnum(c) && c != '_'); });
if (endIdentifier == newCtx.query.begin()) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
lambda = newCtx.query.substr(0, endIdentifier - newCtx.query.begin());
newCtx.query = newCtx.query.substr(endIdentifier - newCtx.query.begin());
if (!newCtx.query.empty() && newCtx.query[0] == ']') {
newCtx.query.erase(0, 1);
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
}
if (newCtx.firstSequence && newCtx.firstNode) {
// First node of first sequence = root node
for (auto node : mGraph->getNodes()) {
if (type.empty() || node->type() == type) {
if ((type.empty() || node->type() == type)
&& (lambda.empty() || mLambda.at(lambda)(node)))
{
MatchingResult result;
result.graph->add(node, false);
if (!anchor.empty()) {
......@@ -499,6 +525,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
for (const auto& output : outputs) {
for (const auto& node : output) {
if ((type.empty() || node.first->type() == type)
&& (lambda.empty() || mLambda.at(lambda)(node.first))
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx))
{
if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) {
......@@ -530,6 +557,7 @@ bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>&
for (const auto& input : inputs) {
if ((type.empty() || input.first->type() == type)
&& (lambda.empty() || mLambda.at(lambda)(input.first))
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx))
{
if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) {
......
......@@ -30,7 +30,7 @@ TEST_CASE("[core/graph] Matching") {
ReLU("relu1"),
PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
ReLU("relu2"),
PaddedConv(8, 16, {5, 5}, "conv3", {1, 1}, {2, 2, 2, 2}),
PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
ReLU("relu3"),
PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
Add(2, "add"),
......@@ -174,4 +174,21 @@ TEST_CASE("[core/graph] Matching") {
const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)");
REQUIRE(results.size() == 0);
}
SECTION("Pad->Conv[3x3]->ReLU") {
auto gm = GraphMatching(g1);
gm.addLambda("3x3", [](const NodePtr& node) {
const std::shared_ptr<Conv_Op<2>> op =
std::static_pointer_cast<Conv_Op<2>>(node->getOperator());
return (op->getAttr<std::array<DimSize_t, 2>>("KernelDims") == std::array<DimSize_t, 2>({3, 3}));
});
const auto results = gm.match("Pad->Conv[3x3]->ReLU");
REQUIRE(results.size() == 1);
for (const auto& result : results) {
fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
REQUIRE(result.graph->getNodes().size() == 3);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment