Skip to content
Snippets Groups Projects
Commit 513273b5 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

First working concept

parent 1a0d54d5
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!138Alternative graph matching
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_GRAPH_MATCHING_H_
#define AIDGE_CORE_GRAPH_MATCHING_H_
#include <map>
#include <memory>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge {
class GraphMatching {
public:
struct Context {
std::string query;
bool firstSequence = true;
bool firstNode = true;
bool inSequence = false;
bool lookForChild = true;
IOIndex_t edgeLeftIdx = 0;
IOIndex_t edgeRightIdx = 0;
size_t depth = 0;
};
struct MatchingResult {
std::shared_ptr<GraphView> graph;
std::map<std::string, std::map<std::string, NodePtr>> anchors;
NodePtr startNode;
MatchingResult() {
graph = std::make_shared<GraphView>();
}
MatchingResult(const MatchingResult& result) {
graph = std::make_shared<GraphView>(*(result.graph.get()));
anchors = result.anchors;
startNode = result.startNode;
}
MatchingResult& operator=(const MatchingResult& result) {
graph = std::make_shared<GraphView>(*(result.graph.get()));
anchors = result.anchors;
startNode = result.startNode;
return *this;
}
};
GraphMatching(std::shared_ptr<GraphView> graph) : mGraph(graph) {}
/**
* Some rules:
* - The first node of the first sequence is the root node and cannot be optional
* WRONG: Conv?->ReLU
* GOOD: ReLU<-Conv?
*
* - The first node of any further sequence must be an existing anchor
* (the anchor cannot be in the middle of the sequence)
* WRONG: Conv->ReLU;Pad->Conv
* Pad->Conv;Conv->ReLU
* GOOD: Conv#->ReLU;Conv#<-Pad
* Pad->Conv#;Conv#->ReLU
*
* - Any node already matched cannot be matched again
*
* - When several nodes could match for a given node query, the first one
* not already in the matching result is matched, following the
* childs/parents ordered node list
* EXAMPLE: Producer in "Conv<*-Producer" will match the weights Producer first
* EXAMPLE: Producer in "Conv#<1-.;Conv#<*-Producer" will match the bias Producer
* because the weights Producer has already been matched
*
* - One always matches a sub-graph: additional connections can exist anywhere
* in the matched sub-graph
* EXAMPLE: "Add<*-." will match the Add operator and its first input, any
* additional inputs will not be included in the result
* EXAMPLE: "(Add#<*-.)+" will match the Add operator and all of its inputs
* Note that the anchor is required since we intend to match several
* inputs of the same node!
*
* - Matching is greedy: the matching GraphView results can be overlapping
* (the same node can be found in different results, except for the root rode)
* EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
* "Conv->ReLU?->Conv?->ReLU?" will match both
* Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2
*
* - Whitespaces are allowed anywhere in the query
*
* QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))*
*/
std::vector<MatchingResult> match(const std::string& query);
private:
std::shared_ptr<GraphView> mGraph;
/**
* NODE_OR_BLOCK = BLOCK | NODE
*/
bool matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches);
/**
* QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
* BLOCK = '(' SEQ | PAR | BLOCK | ALT | NODE ')' QUANTIFIER?
*/
bool matchBlock(Context& ctx, std::vector<MatchingResult>& matches);
/**
* SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+
*/
bool matchSequence(Context& ctx, std::vector<MatchingResult>& matches);
/**
* PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+
*/
bool matchParallel(Context& ctx, std::vector<MatchingResult>& matches);
/**
* ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+
*/
bool matchAlternative(Context& ctx, std::vector<MatchingResult>& matches);
/**
* IO_INDEX_ANY = '*'
* IO_INDEX = IO_INDEX_ANY | [0-9]+
* CHILD_EDGE = '-' (IO_INDEX '-')? IO_INDEX? '>'
* PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? '-'
* EDGE = CHILD_EDGE | PARENT_EDGE
*/
bool matchEdge(Context& ctx, std::vector<MatchingResult>& matches);
/**
* TYPE = [A-Za-z0-9_]+
* ANCHOR = [A-Za-z0-9_]+
* NODE = (TYPE | '.') ('#' ANCHOR)?
*/
bool matchNode(Context& ctx, std::vector<MatchingResult>& matches);
inline void removeWhiteSpace(std::string& str) {
str.erase(str.begin(),
std::find_if(str.begin(),
str.end(),
std::not1(std::ptr_fun<int, int>(std::isspace))));
}
};
} // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */
#include "aidge/graph/Matching.hpp"
#include <fmt/color.h>
std::vector<Aidge::GraphMatching::MatchingResult> Aidge::GraphMatching::match(const std::string& query) {
Context ctx;
ctx.query = query;
std::vector<MatchingResult> matches;
while (matchSequence(ctx, matches) || matchNodeOrBlock(ctx, matches)) {
removeWhiteSpace(ctx.query);
if (!ctx.query.empty() && ctx.query[0] == ';') {
ctx.query.erase(0, 1);
}
else {
break;
}
}
removeWhiteSpace(ctx.query);
if (!ctx.query.empty()) {
Log::warn("Syntax error, unable to parse remaining query: {}", ctx.query);
}
return matches;
}
bool Aidge::GraphMatching::matchNodeOrBlock(Context& ctx, std::vector<MatchingResult>& matches) {
auto newCtx = ctx;
Log::debug("{}node-or-block", std::string(2*newCtx.depth, ' '));
auto newMatches = matches;
++newCtx.depth;
if (!matchBlock(newCtx, newMatches) && !matchNode(newCtx, newMatches)) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
bool matchMore = false;
size_t matchQuantity = 0;
removeWhiteSpace(newCtx.query);
if (!newCtx.query.empty() && (newCtx.query[0] == '?' || newCtx.query[0] == '*')) {
for (const auto& match : matches) {
bool found = false;
for (const auto& newMatch : newMatches) {
if (match.graph->rootNode() == newMatch.graph->rootNode()) {
found = true;
}
}
if (!found) {
newMatches.push_back(match);
}
}
if (newCtx.query[0] == '*') {
matchMore = true;
}
newCtx.query.erase(0, 1);
}
else if (!newCtx.query.empty() && newCtx.query[0] == '+') {
newCtx.query.erase(0, 1);
matchMore = true;
}
else if (!newCtx.query.empty() && newCtx.query[0] == '{') {
newCtx.query.erase(0, 1);
removeWhiteSpace(newCtx.query);
const auto endQuantity = std::find_if(newCtx.query.begin(), newCtx.query.end(),
[](char c) { return !isdigit(c); });
if (endQuantity != newCtx.query.begin()) {
matchQuantity = std::stoi(newCtx.query.substr(0, endQuantity - newCtx.query.begin()));
newCtx.query = newCtx.query.substr(endQuantity - newCtx.query.begin());
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
if (matchQuantity == 0) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
removeWhiteSpace(newCtx.query);
if (!newCtx.query.empty() && newCtx.query[0] == '}') {
newCtx.query.erase(0, 1);
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
if (matchQuantity > 1) {
matchMore = true;
}
}
if (matchMore) {
std::vector<MatchingResult> additionalMatches;
do {
auto additionalCtx = ctx;
additionalCtx.firstNode = newCtx.firstNode;
additionalCtx.firstSequence = newCtx.firstSequence;
++additionalCtx.depth;
additionalMatches = newMatches;
if (!matchBlock(additionalCtx, additionalMatches) && !matchNode(additionalCtx, additionalMatches)) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
for (const auto& additionalMatch : additionalMatches) {
for (auto& match : newMatches) {
if (match.graph->rootNode() == additionalMatch.graph->rootNode()) {
match = additionalMatch;
break;
}
}
}
--matchQuantity;
}
while (!additionalMatches.empty() && matchQuantity > 1);
}
--newCtx.depth;
ctx = newCtx;
matches = newMatches;
return true;
}
bool Aidge::GraphMatching::matchBlock(Context& ctx, std::vector<MatchingResult>& matches) {
auto newCtx = ctx;
Log::debug("{}block", std::string(2*newCtx.depth, ' '));
auto newMatches = matches;
++newCtx.depth;
removeWhiteSpace(newCtx.query);
if (!newCtx.query.empty() && newCtx.query[0] == '(') {
newCtx.query.erase(0, 1);
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
if (!matchSequence(newCtx, newMatches)
&& !matchParallel(newCtx, newMatches)
&& !matchBlock(newCtx, newMatches)
&& !matchAlternative(newCtx, newMatches)
&& !matchNode(newCtx, newMatches))
{
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
removeWhiteSpace(newCtx.query);
if (!newCtx.query.empty() && newCtx.query[0] == ')') {
newCtx.query.erase(0, 1);
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
--newCtx.depth;
ctx = newCtx;
matches = newMatches;
return true;
}
bool Aidge::GraphMatching::matchSequence(Context& ctx, std::vector<MatchingResult>& matches) {
auto newCtx = ctx;
Log::debug("{}sequence", std::string(2*newCtx.depth, ' '));
auto newMatches = matches;
++newCtx.depth;
if (!ctx.inSequence) {
newCtx.inSequence = true;
newCtx.firstNode = true;
}
if (!matchNodeOrBlock(newCtx, newMatches)) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
newCtx.firstNode = false;
bool found = false;
while (true) {
if (matchEdge(newCtx, newMatches)) {
found = true;
}
else {
break;
}
if (!matchNodeOrBlock(newCtx, newMatches)) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
}
if (!found) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
if (!ctx.inSequence) {
newCtx.inSequence = false;
}
--newCtx.depth;
ctx = newCtx;
matches = newMatches;
return true;
}
bool Aidge::GraphMatching::matchParallel(Context& /*ctx*/, std::vector<MatchingResult>& /*matches*/) {
// TODO
return false;
}
bool Aidge::GraphMatching::matchAlternative(Context& ctx, std::vector<MatchingResult>& matches) {
auto newCtx = ctx;
Log::debug("{}alternative", std::string(2*newCtx.depth, ' '));
++newCtx.depth;
std::vector<MatchingResult> newMatches;
auto altCtx = newCtx;
auto altMatches = matches;
if (!matchNodeOrBlock(altCtx, altMatches)) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
newCtx.query = altCtx.query;
newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end());
bool found = false;
while (true) {
removeWhiteSpace(newCtx.query);
if (!newCtx.query.empty() && newCtx.query[0] == '|') {
newCtx.query.erase(0, 1);
found = true;
}
else {
break;
}
altCtx = newCtx;
altMatches = matches;
if (!matchNodeOrBlock(altCtx, altMatches)) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
newCtx.query = altCtx.query;
newMatches.insert(newMatches.end(), altMatches.begin(), altMatches.end());
}
if (!found) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
--newCtx.depth;
ctx = newCtx;
matches = newMatches;
return true;
}
bool Aidge::GraphMatching::matchEdge(Context& ctx, std::vector<MatchingResult>& /*matches*/) {
auto newCtx = ctx;
Log::debug("{}edge", std::string(2*newCtx.depth, ' '));
removeWhiteSpace(newCtx.query);
if (!newCtx.query.empty() && newCtx.query[0] == '-') {
newCtx.query.erase(0, 1); // drop '-'
newCtx.lookForChild = true;
}
else if (!newCtx.query.empty() && newCtx.query[0] == '<') {
newCtx.query.erase(0, 1); // drop '<'
newCtx.lookForChild = false;
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
int firstIdx = 0;
bool foundFirst = false;
const auto endOutputIdx = std::find_if(newCtx.query.begin(), newCtx.query.end(),
[](char c) { return !isdigit(c); });
if (endOutputIdx != newCtx.query.begin()) {
firstIdx = std::stoi(newCtx.query.substr(0, endOutputIdx - newCtx.query.begin()));
newCtx.query = newCtx.query.substr(endOutputIdx - newCtx.query.begin());
foundFirst = true;
}
else if (newCtx.query[0] == '*') {
newCtx.query.erase(0, 1); // drop '*'
firstIdx = -1;
foundFirst = true;
}
int secondIdx = 0;
bool foundSecond = false;
if (foundFirst && !newCtx.query.empty() && newCtx.query[0] == '-') {
auto query = newCtx.query;
query.erase(0, 1); // drop '-'
const auto endInputIdx = std::find_if(query.begin(), query.end(),
[](char c) { return !isdigit(c); });
if (endInputIdx != query.begin()) {
secondIdx = std::stoi(query.substr(0, endInputIdx - query.begin()));
query = query.substr(endInputIdx - query.begin());
foundSecond = true;
}
else if (query[0] == '*') {
query.erase(0, 1); // drop '*'
secondIdx = -1;
foundSecond = true;
}
if (foundSecond) {
newCtx.query = query;
}
}
if (newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '>') {
newCtx.query.erase(0, 1); // drop '>'
}
else if (!newCtx.lookForChild && !newCtx.query.empty() && newCtx.query[0] == '-') {
newCtx.query.erase(0, 1); // drop '-'
}
else {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
newCtx.edgeLeftIdx = 0;
newCtx.edgeRightIdx = 0;
if (foundFirst && foundSecond) {
newCtx.edgeLeftIdx = firstIdx;
newCtx.edgeRightIdx = secondIdx;
}
else if (foundFirst) {
if (newCtx.lookForChild) {
newCtx.edgeRightIdx = firstIdx;
}
else {
newCtx.edgeLeftIdx = firstIdx;
}
}
if (newCtx.lookForChild) {
Log::debug("{}-{}-{}>", std::string(2*newCtx.depth + 2, ' '),
newCtx.edgeLeftIdx, newCtx.edgeRightIdx);
}
else {
Log::debug("{}<{}-{}-", std::string(2*newCtx.depth + 2, ' '),
newCtx.edgeLeftIdx, newCtx.edgeRightIdx);
}
ctx = newCtx;
return true;
}
bool Aidge::GraphMatching::matchNode(Context& ctx, std::vector<MatchingResult>& matches) {
auto newCtx = ctx;
Log::debug("{}node", std::string(2*newCtx.depth, ' '));
auto newMatches = matches;
removeWhiteSpace(newCtx.query);
if (newCtx.query.empty()) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
std::string type;
if (newCtx.query[0] == '.') {
newCtx.query.erase(0, 1); // drop '.'
}
else {
const auto endIdentifier = std::find_if(newCtx.query.begin(), newCtx.query.end(),
[](char c) { return (!isalnum(c) && c != '_'); });
if (endIdentifier == newCtx.query.begin()) {
Log::debug("{}{}", std::string(2*ctx.depth, ' '), fmt::styled("×", fmt::fg(fmt::color::red)));
return false;
}
type = newCtx.query.substr(0, endIdentifier - newCtx.query.begin());
newCtx.query = newCtx.query.substr(endIdentifier - newCtx.query.begin());
}
std::string anchor = "";
if (!newCtx.query.empty() && newCtx.query[0] == '#') {
newCtx.query.erase(0, 1); // drop '#'
const auto endAnchor = std::find_if(newCtx.query.begin(), newCtx.query.end(),
[](char c) { return (!isalnum(c) && c != '_'); });
anchor = "#" + newCtx.query.substr(0, endAnchor - newCtx.query.begin());
newCtx.query = newCtx.query.substr(endAnchor - newCtx.query.begin());
}
if (newCtx.firstSequence && newCtx.firstNode) {
// First node of first sequence = root node
for (auto node : mGraph->getNodes()) {
if (type.empty() || node->type() == type) {
MatchingResult result;
result.graph->add(node, false);
if (!anchor.empty()) {
result.anchors[type][anchor] = node;
}
result.startNode = node;
newMatches.push_back(result);
}
}
newCtx.firstSequence = false;
Log::debug("{}root node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size());
}
else if (newCtx.firstNode) {
// First node of a (new) sequence: it has to be an existing anchor
for (size_t i = 0; i < newMatches.size(); ) {
const auto anchors = newMatches[i].anchors[type];
const auto anchorNode = anchors.find(anchor);
if (anchorNode != anchors.end()) {
newMatches[i].startNode = anchorNode->second;
++i;
}
else {
newMatches.erase(newMatches.begin() + i);
}
}
Log::debug("{}anchor node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size());
}
else {
for (size_t i = 0; i < newMatches.size(); ) {
bool found = false;
if (newCtx.lookForChild) {
const auto outputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex)
? std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>(1, std::vector<std::pair<NodePtr, IOIndex_t>>(newMatches[i].startNode->output(newCtx.edgeLeftIdx)))
: newMatches[i].startNode->outputs();
for (const auto& output : outputs) {
for (const auto& node : output) {
if ((type.empty() || node.first->type() == type)
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || node.second == newCtx.edgeRightIdx))
{
if (mGraph->inView(node.first) && !newMatches[i].graph->inView(node.first)) {
newMatches[i].graph->add(node.first, false);
if (!anchor.empty()) {
newMatches[i].anchors[type][anchor] = node.first;
}
newMatches[i].startNode = node.first;
found = true;
break;
}
}
}
if (found) {
break;
}
}
}
else {
const auto inputs = (newCtx.edgeLeftIdx != gk_IODefaultIndex)
? std::vector<std::pair<NodePtr, IOIndex_t>>(1, newMatches[i].startNode->input(newCtx.edgeLeftIdx))
: newMatches[i].startNode->inputs();
for (const auto& input : inputs) {
if ((type.empty() || input.first->type() == type)
&& (newCtx.edgeRightIdx == gk_IODefaultIndex || input.second == newCtx.edgeRightIdx))
{
if (mGraph->inView(input.first) && !newMatches[i].graph->inView(input.first)) {
newMatches[i].graph->add(input.first, false);
if (!anchor.empty()) {
newMatches[i].anchors[type][anchor] = input.first;
}
newMatches[i].startNode = input.first;
found = true;
break;
}
}
}
}
if (found) {
++i;
}
else {
newMatches.erase(newMatches.begin() + i);
}
}
Log::debug("{}node {}{}, found: {}", std::string(2*newCtx.depth + 2, ' '), fmt::styled(type.empty() ? "." : type, fmt::fg(fmt::color::yellow)), anchor, newMatches.size());
}
ctx = newCtx;
matches = newMatches;
return true;
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/recipes/Recipes.hpp"
using namespace Aidge;
TEST_CASE("[core/graph] Matching") {
auto g1 = Sequential({
Producer({16, 3, 512, 512}, "dataProvider"),
Conv(3, 4, {5, 5}, "conv1"),
ReLU("relu1"),
PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
ReLU("relu2"),
PaddedConv(8, 16, {5, 5}, "conv3", {1, 1}, {2, 2, 2, 2}),
ReLU("relu3"),
PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
Add(2, "add"),
PaddedConv(8, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
ReLU("relu5"),
Add(2, "add2")
});
g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1);
g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1);
g1->updateInputsOutputs();
g1->save("Test_examples_before_expand", true);
expandMetaOps(g1);
g1->save("Test_examples", true);
SECTION("Conv->(ReLU->Pad->Conv)*") {
auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*");
REQUIRE(results.size() == 5);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
}
}
SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") {
auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer");
REQUIRE(results.size() == 3);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 5);
}
}
SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") {
auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}");
REQUIRE(results.size() == 3);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 5);
}
}
SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") {
auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}");
REQUIRE(results.size() == 3);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 5);
}
}
SECTION("Conv#->ReLU*;Conv#<-Pad*") {
auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*");
REQUIRE(results.size() == 5);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3));
}
}
SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") {
auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?");
REQUIRE(results.size() == 5);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
}
}
SECTION("Conv#->ReLU?;Conv#<-Pad?") {
auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?");
REQUIRE(results.size() == 5);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3));
}
}
SECTION("(Conv|ReLU)->Add") {
auto results = GraphMatching(g1).match("(Conv|ReLU)->Add");
REQUIRE(results.size() == 2);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 2);
}
}
SECTION("Add<*-.") {
auto results = GraphMatching(g1).match("Add<*-.");
REQUIRE(results.size() == 2);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 2);
}
}
SECTION("(Add#<*-.)+") {
auto results = GraphMatching(g1).match("(Add#<*-.)+");
REQUIRE(results.size() == 2);
for (auto result : results) {
std::vector<std::string> nodesName;
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 3);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment