diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp new file mode 100644 index 0000000000000000000000000000000000000000..de3dddc3af49a93b403c216a382171b4612e36c4 --- /dev/null +++ b/include/aidge/graph/Matching.hpp @@ -0,0 +1,156 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_GRAPH_MATCHING_H_ +#define AIDGE_CORE_GRAPH_MATCHING_H_ + +#include <map> +#include <memory> + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" + +namespace Aidge { +class GraphMatching { +public: + struct Context { + std::string query; + bool firstSequence = true; + bool firstNode = true; + bool inSequence = false; + bool lookForChild = true; + IOIndex_t edgeLeftIdx = 0; + IOIndex_t edgeRightIdx = 0; + size_t depth = 0; + }; + + struct MatchingResult { + std::shared_ptr<GraphView> graph; + std::map<std::string, std::map<std::string, NodePtr>> anchors; + NodePtr startNode; + + MatchingResult() { + graph = std::make_shared<GraphView>(); + } + + MatchingResult(const MatchingResult& result) { + graph = std::make_shared<GraphView>(*(result.graph.get())); + anchors = result.anchors; + startNode = result.startNode; + } + + MatchingResult& operator=(const MatchingResult& result) { + graph = std::make_shared<GraphView>(*(result.graph.get())); + anchors = result.anchors; + startNode = result.startNode; + return *this; + } + }; + + GraphMatching(std::shared_ptr<GraphView> graph) : mGraph(graph) {} + + /** + * Some rules: + * - The first node of the first sequence is the root node and cannot be optional + * WRONG: Conv?->ReLU + * GOOD: ReLU<-Conv? + * + * - The first node of any further sequence must be an existing anchor + * (the anchor cannot be in the middle of the sequence) + * WRONG: Conv->ReLU;Pad->Conv + * Pad->Conv;Conv->ReLU + * GOOD: Conv#->ReLU;Conv#<-Pad + * Pad->Conv#;Conv#->ReLU + * + * - Any node already matched cannot be matched again + * + * - When several nodes could match for a given node query, the first one + * not already in the matching result is matched, following the + * childs/parents ordered node list + * EXAMPLE: Producer in "Conv<*-Producer" will match the weights Producer first + * EXAMPLE: Producer in "Conv#<1-.;Conv#<*-Producer" will match the bias Producer + * because the weights Producer has already been matched + * + * - One always matches a sub-graph: additional connections can exist anywhere + * in the matched sub-graph + * EXAMPLE: "Add<*-." will match the Add operator and its first input, any + * additional inputs will not be included in the result + * EXAMPLE: "(Add#<*-.)+" will match the Add operator and all of its inputs + * Note that the anchor is required since we intend to match several + * inputs of the same node! + * + * - Matching is greedy: the matching GraphView results can be overlapping + * (the same node can be found in different results, except for the root rode) + * EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2 + * "Conv->ReLU?->Conv?->ReLU?" will match both + * Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2 + * + * - Whitespaces are allowed anywhere in the query + * + * QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))* + */ + std::vector<MatchingResult> match(const std::string& query); + +private: + std::shared_ptr<GraphView> mGraph; + + /** + * NODE_OR_BLOCK = BLOCK | NODE + */ + bool matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches); + + /** + * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}') + * BLOCK = '(' SEQ | PAR | BLOCK | ALT | NODE ')' QUANTIFIER? + */ + bool matchBlock(Context& ctx, std::vector<MatchingResult>& matches); + + /** + * SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+ + */ + bool matchSequence(Context& ctx, std::vector<MatchingResult>& matches); + + /** + * PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+ + */ + bool matchParallel(Context& ctx, std::vector<MatchingResult>& matches); + + /** + * ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+ + */ + bool matchAlternative(Context& ctx, std::vector<MatchingResult>& matches); + + /** + * IO_INDEX_ANY = '*' + * IO_INDEX = IO_INDEX_ANY | [0-9]+ + * CHILD_EDGE = '-' (IO_INDEX '-')? IO_INDEX? '>' + * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? '-' + * EDGE = CHILD_EDGE | PARENT_EDGE + */ + bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches); + + /** + * TYPE = [A-Za-z0-9_]+ + * ANCHOR = [A-Za-z0-9_]+ + * NODE = (TYPE | '.') ('#' ANCHOR)? + */ + bool matchNode(Context& ctx, std::vector<MatchingResult>& matches); + + inline void removeWhiteSpace(std::string& str) { + str.erase(str.begin(), + std::find_if(str.begin(), + str.end(), + std::not1(std::ptr_fun<int, int>(std::isspace)))); + } +}; +} // namespace Aidge + +#endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */ diff --git a/src/graph/Matching.cpp b/src/graph/Matching.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b763c13aafad8fae266d917fe2cf7afdc0d24cf --- /dev/null +++ b/src/graph/Matching.cpp @@ -0,0 +1,508 @@ +#include "aidge/graph/Matching.hpp" + +#include <fmt/color.h> + +std::vector<Aidge::GraphMatching::MatchingResult> Aidge::GraphMatching::match(const std::string& query) { + Context ctx; + ctx.query = query; + std::vector<MatchingResult> matches; + + while (matchSequence(ctx, matches) || matchNodeOrBlock(ctx, matches)) { + removeWhiteSpace(ctx.query); + if (!ctx.query.empty() && ctx.query[0] == ';') { + ctx.query.erase(0, 1); + } + else { + break; + } + } + + removeWhiteSpace(ctx.query); + if (!ctx.query.empty()) { + Log::warn("Syntax error, unable to parse remaining query: {}", ctx.query); + } + + return matches; +} + +bool Aidge::GraphMatching::matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches) { + auto newCtx = ctx; + Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' ')); + auto newMatches = matches; + ++newCtx.depth; + + if (!matchBlock(newCtx, newMatches) && !matchNode(newCtx, newMatches)) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + bool matchMore = false; + size_t matchQuantity = 0; + removeWhiteSpace(newCtx.query); + if (!newCtx.query.empty() && (newCtx.query[0] == '?' || newCtx.query[0] == '*')) { + for (const auto& match : matches) { + bool found = false; + for (const auto& newMatch : newMatches) { + if (match.graph->rootNode() == newMatch.graph->rootNode()) { + found = true; + } + } + + if (!found) { + newMatches.push_back(match); + } + } + + if (newCtx.query[0] == '*') { + matchMore = true; + } + + newCtx.query.erase(0, 1); + } + else if (!newCtx.query.empty() && newCtx.query[0] == '+') { + newCtx.query.erase(0, 1); + matchMore = true; + } + else if (!newCtx.query.empty() && newCtx.query[0] == '{') { + newCtx.query.erase(0, 1); + + removeWhiteSpace(newCtx.query); + const auto endQuantity = std::find_if(newCtx.query.begin(), newCtx.query.end(), + [](char c) { return !isdigit(c); }); + if (endQuantity != newCtx.query.begin()) { + matchQuantity = std::stoi(newCtx.query.substr(0, endQuantity - newCtx.query.begin())); + newCtx.query = newCtx.query.substr(endQuantity - newCtx.query.begin()); + } + else { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + if (matchQuantity == 0) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + removeWhiteSpace(newCtx.query); + if (!newCtx.query.empty() && newCtx.query[0] == '}') { + newCtx.query.erase(0, 1); + } + else { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + if (matchQuantity > 1) { + matchMore = true; + } + } + + if (matchMore) { + std::vector<MatchingResult> additionalMatches; + + do { + auto additionalCtx = ctx; + additionalCtx.firstNode = newCtx.firstNode; + additionalCtx.firstSequence = newCtx.firstSequence; + ++additionalCtx.depth; + additionalMatches = newMatches; + + if (!matchBlock(additionalCtx, additionalMatches) && !matchNode(additionalCtx, additionalMatches)) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + for (const auto& additionalMatch : additionalMatches) { + for (auto& match : newMatches) { + if (match.graph->rootNode() == additionalMatch.graph->rootNode()) { + match = additionalMatch; + break; + } + } + } + + --matchQuantity; + } + while (!additionalMatches.empty() && matchQuantity > 1); + } + + --newCtx.depth; + ctx = newCtx; + matches = newMatches; + return true; +} + +bool Aidge::GraphMatching::matchBlock(Context& ctx, std::vector<MatchingResult>& matches) { + auto newCtx = ctx; + Log::debug("{}block", std::string(2*newCtx.depth, ' ')); + auto newMatches = matches; + ++newCtx.depth; + + removeWhiteSpace(newCtx.query); + if (!newCtx.query.empty() && newCtx.query[0] == '(') { + newCtx.query.erase(0, 1); + } + else { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + if (!matchSequence(newCtx, newMatches) + && !matchParallel(newCtx, newMatches) + && !matchBlock(newCtx, newMatches) + && !matchAlternative(newCtx, newMatches) + && !matchNode(newCtx, newMatches)) + { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + removeWhiteSpace(newCtx.query); + if (!newCtx.query.empty() && newCtx.query[0] == ')') { + newCtx.query.erase(0, 1); + } + else { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + --newCtx.depth; + ctx = newCtx; + matches = newMatches; + return true; +} + +bool Aidge::GraphMatching::matchSequence(Context& ctx, std::vector<MatchingResult>& matches) { + auto newCtx = ctx; + Log::debug("{}sequence", std::string(2*newCtx.depth, ' ')); + auto newMatches = matches; + ++newCtx.depth; + + if (!ctx.inSequence) { + newCtx.inSequence = true; + newCtx.firstNode = true; + } + + if (!matchNodeOrBlock(newCtx, newMatches)) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + newCtx.firstNode = false; + + bool found = false; + while (true) { + if (matchEdge(newCtx, newMatches)) { + found = true; + } + else { + break; + } + + if (!matchNodeOrBlock(newCtx, newMatches)) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + } + + if (!found) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + if (!ctx.inSequence) { + newCtx.inSequence = false; + } + + --newCtx.depth; + ctx = newCtx; + matches = newMatches; + return true; +} + +bool Aidge::GraphMatching::matchParallel(Context& /*ctx*/, std::vector<MatchingResult>& /*matches*/) { + // TODO + return false; +} + +bool Aidge::GraphMatching::matchAlternative(Context& ctx, std::vector<MatchingResult>& matches) { + auto newCtx = ctx; + Log::debug("{}alternative", std::string(2*newCtx.depth, ' ')); + ++newCtx.depth; + std::vector<MatchingResult> newMatches; + + auto altCtx = newCtx; + auto altMatches = matches; + if (!matchNodeOrBlock(altCtx, altMatches)) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + newCtx.query = altCtx.query; + newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end()); + + bool found = false; + while (true) { + removeWhiteSpace(newCtx.query); + if (!newCtx.query.empty() && newCtx.query[0] == '|') { + newCtx.query.erase(0, 1); + found = true; + } + else { + break; + } + + altCtx = newCtx; + altMatches = matches; + if (!matchNodeOrBlock(altCtx, altMatches)) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + newCtx.query = altCtx.query; + newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end()); + } + + if (!found) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + --newCtx.depth; + ctx = newCtx; + matches = newMatches; + return true; +} + +bool Aidge::GraphMatching::matchEdge(Context& ctx, std::vector<MatchingResult>& /*matches*/) { + auto newCtx = ctx; + Log::debug("{}edge", std::string(2*newCtx.depth, ' ')); + + removeWhiteSpace(newCtx.query); + if (!newCtx.query.empty() && newCtx.query[0] == '-') { + newCtx.query.erase(0, 1); // drop '-' + newCtx.lookForChild = true; + } + else if (!newCtx.query.empty() && newCtx.query[0] == '<') { + newCtx.query.erase(0, 1); // drop '<' + newCtx.lookForChild = false; + } + else { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + int firstIdx = 0; + bool foundFirst = false; + const auto endOutputIdx = std::find_if(newCtx.query.begin(), newCtx.query.end(), + [](char c) { return !isdigit(c); }); + if (endOutputIdx != newCtx.query.begin()) { + firstIdx = std::stoi(newCtx.query.substr(0, endOutputIdx - newCtx.query.begin())); + newCtx.query = newCtx.query.substr(endOutputIdx - newCtx.query.begin()); + foundFirst = true; + } + else if (newCtx.query[0] == '*') { + newCtx.query.erase(0, 1); // drop '*' + firstIdx = -1; + foundFirst = true; + } + + int secondIdx = 0; + bool foundSecond = false; + if (foundFirst && !newCtx.query.empty() && newCtx.query[0] == '-') { + auto query = newCtx.query; + query.erase(0, 1); // drop '-' + + const auto endInputIdx = std::find_if(query.begin(), query.end(), + [](char c) { return !isdigit(c); }); + if (endInputIdx != query.begin()) { + secondIdx = std::stoi(query.substr(0, endInputIdx - query.begin())); + query = query.substr(endInputIdx - query.begin()); + foundSecond = true; + } + else if (query[0] == '*') { + query.erase(0, 1); // drop '*' + secondIdx = -1; + foundSecond = true; + } + + if (foundSecond) { + newCtx.query = query; + } + } + + if (newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '>') { + newCtx.query.erase(0, 1); // drop '>' + } + else if (!newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '-') { + newCtx.query.erase(0, 1); // drop '-' + } + else { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + newCtx.edgeLeftIdx = 0; + newCtx.edgeRightIdx = 0; + if (foundFirst && foundSecond) { + newCtx.edgeLeftIdx = firstIdx; + newCtx.edgeRightIdx = secondIdx; + } + else if (foundFirst) { + if (newCtx.lookForChild) { + newCtx.edgeRightIdx = firstIdx; + } + else { + newCtx.edgeLeftIdx = firstIdx; + } + } + + if (newCtx.lookForChild) { + Log::debug("{}-{}-{}>", std::string(2*newCtx.depth + 2, ' '), + newCtx.edgeLeftIdx, newCtx.edgeRightIdx); + } + else { + Log::debug("{}<{}-{}-", std::string(2*newCtx.depth + 2, ' '), + newCtx.edgeLeftIdx, newCtx.edgeRightIdx); + } + + ctx = newCtx; + return true; +} + +bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& matches) { + auto newCtx = ctx; + Log::debug("{}node", std::string(2*newCtx.depth, ' ')); + auto newMatches = matches; + + removeWhiteSpace(newCtx.query); + if (newCtx.query.empty()) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + std::string type; + if (newCtx.query[0] == '.') { + newCtx.query.erase(0, 1); // drop '.' + } + else { + const auto endIdentifier = std::find_if(newCtx.query.begin(), newCtx.query.end(), + [](char c) { return (!isalnum(c) && c != '_'); }); + + if (endIdentifier == newCtx.query.begin()) { + Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red))); + return false; + } + + type = newCtx.query.substr(0, endIdentifier - newCtx.query.begin()); + newCtx.query = newCtx.query.substr(endIdentifier - newCtx.query.begin()); + } + + std::string anchor = ""; + if (!newCtx.query.empty() && newCtx.query[0] == '#') { + newCtx.query.erase(0, 1); // drop '#' + const auto endAnchor = std::find_if(newCtx.query.begin(), newCtx.query.end(), + [](char c) { return (!isalnum(c) && c != '_'); }); + anchor = "#" + newCtx.query.substr(0, endAnchor - newCtx.query.begin()); + newCtx.query = newCtx.query.substr(endAnchor - newCtx.query.begin()); + } + + if (newCtx.firstSequence && newCtx.firstNode) { + // First node of first sequence = root node + for (auto node : mGraph->getNodes()) { + if (type.empty() || node->type() == type) { + MatchingResult result; + result.graph->add(node, false); + if (!anchor.empty()) { + result.anchors[type][anchor] = node; + } + result.startNode = node; + newMatches.push_back(result); + } + } + newCtx.firstSequence = false; + + Log::debug("{}root node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size()); + } + else if (newCtx.firstNode) { + // First node of a (new) sequence: it has to be an existing anchor + for (size_t i = 0; i < newMatches.size(); ) { + const auto anchors = newMatches[i].anchors[type]; + const auto anchorNode = anchors.find(anchor); + if (anchorNode != anchors.end()) { + newMatches[i].startNode = anchorNode->second; + ++i; + } + else { + newMatches.erase(newMatches.begin() + i); + } + } + + Log::debug("{}anchor node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size()); + } + else { + for (size_t i = 0; i < newMatches.size(); ) { + bool found = false; + + if (newCtx.lookForChild) { + const auto outputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) + ? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(newMatches[i].startNode->output(newCtx.edgeLeftIdx))) + : newMatches[i].startNode->outputs(); + + for (const auto& output : outputs) { + for (const auto& node : output) { + if ((type.empty() || node.first->type() == type) + && (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx)) + { + if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) { + newMatches[i].graph->add(node.first, false); + if (!anchor.empty()) { + newMatches[i].anchors[type][anchor] = node.first; + } + newMatches[i].startNode = node.first; + found = true; + break; + } + } + } + + if (found) { + break; + } + } + } + else { + const auto inputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex) + ? std::vector<std::pair<NodePtr, IOIndex_t>>(1, newMatches[i].startNode->input(newCtx.edgeLeftIdx)) + : newMatches[i].startNode->inputs(); + + for (const auto& input : inputs) { + if ((type.empty() || input.first->type() == type) + && (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx)) + { + if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) { + newMatches[i].graph->add(input.first, false); + if (!anchor.empty()) { + newMatches[i].anchors[type][anchor] = input.first; + } + newMatches[i].startNode = input.first; + found = true; + break; + } + } + } + } + + if (found) { + ++i; + } + else { + newMatches.erase(newMatches.begin() + i); + } + } + + Log::debug("{}node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size()); + } + + ctx = newCtx; + matches = newMatches; + return true; +} diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1613ab1beda7eb1618608d8b0dbb4a7fb7fbaa67 --- /dev/null +++ b/unit_tests/graph/Test_Matching.cpp @@ -0,0 +1,194 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graph/Matching.hpp" +#include "aidge/recipes/Recipes.hpp" + +using namespace Aidge; + +TEST_CASE("[core/graph] Matching") { + auto g1 = Sequential({ + Producer({16, 3, 512, 512}, "dataProvider"), + Conv(3, 4, {5, 5}, "conv1"), + ReLU("relu1"), + PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}), + ReLU("relu2"), + PaddedConv(8, 16, {5, 5}, "conv3", {1, 1}, {2, 2, 2, 2}), + ReLU("relu3"), + PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), + Add(2, "add"), + PaddedConv(8, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), + ReLU("relu5"), + Add(2, "add2") + }); + + g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1); + g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1); + g1->updateInputsOutputs(); + + g1->save("Test_examples_before_expand", true); + expandMetaOps(g1); + g1->save("Test_examples", true); + + SECTION("Conv->(ReLU->Pad->Conv)*") { + auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); + REQUIRE(results.size() == 5); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + } + } + + SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") { + auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); + REQUIRE(results.size() == 3); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE(result.graph->getNodes().size() == 5); + } + } + + SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") { + auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); + REQUIRE(results.size() == 3); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE(result.graph->getNodes().size() == 5); + } + } + + SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") { + auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}"); + REQUIRE(results.size() == 3); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE(result.graph->getNodes().size() == 5); + } + } + + SECTION("Conv#->ReLU*;Conv#<-Pad*") { + auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*"); + REQUIRE(results.size() == 5); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); + } + } + + SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { + auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); + REQUIRE(results.size() == 5); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + } + } + + SECTION("Conv#->ReLU?;Conv#<-Pad?") { + auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); + REQUIRE(results.size() == 5); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); + } + } + + SECTION("(Conv|ReLU)->Add") { + auto results = GraphMatching(g1).match("(Conv|ReLU)->Add"); + REQUIRE(results.size() == 2); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE(result.graph->getNodes().size() == 2); + } + } + + SECTION("Add<*-.") { + auto results = GraphMatching(g1).match("Add<*-."); + REQUIRE(results.size() == 2); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE(result.graph->getNodes().size() == 2); + } + } + + SECTION("(Add#<*-.)+") { + auto results = GraphMatching(g1).match("(Add#<*-.)+"); + REQUIRE(results.size() == 2); + + for (auto result : results) { + std::vector<std::string> nodesName; + std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), + std::back_inserter(nodesName), + [](auto val){ return val->name(); }); + fmt::print("Found: {}\n", nodesName); + + REQUIRE(result.graph->getNodes().size() == 3); + } + } +}