Forked from
Eclipse Projects / aidge / aidge_core
697 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_Matching.cpp 17.67 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <fmt/chrono.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/recipes/Recipes.hpp"
using namespace Aidge;
/**
 * Compare the output of SinglePassGraphMatching::match() against expectations.
 *
 * @param results  Matches returned by the matcher.
 * @param expected Map from the name of each match's root node to the exact set
 *                 of node names that match is expected to contain.
 *
 * The number of matches is checked non-fatally (CHECK) so that the per-match
 * comparisons below still run and report which match differs.
 */
void checkMatches(const std::set<SinglePassGraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) {
    CHECK(results.size() == expected.size());
    for (const auto& result : results) {
        const auto found = nodePtrTo(result.graph->getNodes(), nodePtrToName);
        const auto rootNode = result.graph->rootNode()->name();
        // Print the root as well, so a failure below can be tied to a match.
        fmt::print("Found: {} (root: {})\n", found, rootNode);

        // Use find() + REQUIRE rather than map::at(): an unexpected root then
        // fails as a Catch assertion instead of an uncaught std::out_of_range.
        const auto expectedIt = expected.find(rootNode);
        REQUIRE(expectedIt != expected.end());
        REQUIRE(found == expectedIt->second);
    }
}
/**
 * Exercises SinglePassGraphMatching queries against two hand-built graphs.
 *
 * g1 (flattened with expandMetaOps before matching) is the chain
 *   dataProvider -> conv1 -> relu1 -> conv2 -> relu2 -> conv3 -> relu3
 *   -> conv4 -> add -> conv5 -> relu5 -> add2
 * plus two extra edges: relu3 -> add (input #1) and conv5 -> add2 (input #1).
 * conv2..conv5 are PaddedConv meta-operators. The expectations below refer to
 * their expanded nodes as "<name>_pad" / "<name>_conv" and to their parameter
 * producers as "<name>_w" / "<name>_b".
 *
 * g2 is a plain chain of Conv / BatchNorm / FC / ReLU nodes. It is used by the
 * "first/last node of a type" queries.
 *
 * Catch2 re-runs the whole test case once per leaf SECTION, so the setup code
 * (including the save() calls) executes again for every section.
 */
TEST_CASE("[core/graph] Matching") {
    auto g1 = Sequential({
        Producer({16, 3, 512, 512}, "dataProvider"),
        Conv(3, 4, {5, 5}, "conv1"),
        ReLU("relu1"),
        PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
        ReLU("relu2"),
        PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
        ReLU("relu3"),
        PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
        Add("add"),
        PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
        ReLU("relu5"),
        Add("add2")
    });

    // Second inputs of the two Add nodes: they turn the chain into a DAG.
    g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1);
    g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1);
    g1->updateInputsOutputs();

    g1->save("Test_examples_before_expand", true);
    expandMetaOps(g1);
    g1->save("Test_examples", true);

    SECTION("Conv2D->(ReLU->Pad2D->Conv2D)*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D->(ReLU->Pad2D->Conv2D)*");

        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "relu1", "relu2"}},
            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv"}},
            {"conv4_conv", {"conv4_conv"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    // Queries the matcher is expected to reject.
    SECTION("Conv2D->ReLU;ReLU->Pad2D") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv2D->ReLU;ReLU->Pad2D"));
    }

    SECTION("Conv2D->ReLU#1;ReLU#2->Pad2D") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv2D->ReLU#1;ReLU#2->Pad2D"));
    }

    SECTION("Conv2D?->ReLU") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv2D?->ReLU"));
    }

    SECTION("(Add#<*~.)*") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("(Add#<*~.)*"));
    }

    SECTION("Conv2D->(ReLU~>Pad2D->Conv2D)*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D->(ReLU~>Pad2D->Conv2D)*");

        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
            {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    SECTION("Conv2D->(ReLU~>Pad2D->Conv2D)* [disjoint]") {
        // Second argument `true` requests disjoint matches (per the section name).
        const auto results = SinglePassGraphMatching(g1).match("Conv2D->(ReLU~>Pad2D->Conv2D)*", true);

        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    // Section name now matches the query it runs.
    SECTION("Conv2D~>(ReLU~>Pad2D->Conv2D)*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D~>(ReLU~>Pad2D->Conv2D)*");

        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
            {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    SECTION("Pad2D->Conv2D#->ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#->ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
        });
    }

    SECTION("Pad2D->Conv2D#~>ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#~>ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
        });
    }

    SECTION("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-Producer){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-Producer){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
        });
    }

    SECTION("Pad2D->Conv2D#->ReLU;(Conv2D#<*-Producer){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#->ReLU;(Conv2D#<*-Producer){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
        });
    }

    SECTION("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-.){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-.){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
        });
    }

    SECTION("Pad2D->Conv2D#->ReLU;(Conv2D#<*-.){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#->ReLU;(Conv2D#<*-.){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
        });
    }

    // Section name now matches the query it runs.
    SECTION("Conv2D#~>ReLU*;Conv2D#<-Pad2D*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU*;Conv2D#<-Pad2D*");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("Conv2D#->ReLU*;Conv2D#<-Pad2D*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#->ReLU*;Conv2D#<-Pad2D*");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
        });
    }

    SECTION("Conv2D#->ReLU?-*>Add#1?->ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#->ReLU?-*>Add#1?->ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
        });
    }

    // Section name now matches the query it runs.
    SECTION("Conv2D#~>ReLU?-*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU?-*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("Conv2D#~>ReLU?~*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*~.)?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU?~*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*~.)?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"add", "conv3_conv", "conv3_pad", "conv4_conv", "relu3"}},
            {"conv4_conv", {"add", "conv4_conv", "conv4_pad", "relu3"}},
            {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("Conv2D#->ReLU?;Conv2D#<-Pad2D?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#->ReLU?;Conv2D#<-Pad2D?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
        });
    }

    SECTION("Conv2D#~>ReLU?;Conv2D#<-Pad2D?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU?;Conv2D#<-Pad2D?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("(Conv2D|ReLU)->Add") {
        const auto results = SinglePassGraphMatching(g1).match("(Conv2D|ReLU)->Add");

        checkMatches(results, {
            {"conv4_conv", {"add", "conv4_conv"}},
            {"relu5", {"add2", "relu5"}}
        });
    }

    SECTION("Add<*-.") {
        const auto results = SinglePassGraphMatching(g1).match("Add<*-.");

        checkMatches(results, {
            {"add", {"add", "conv4_conv"}},
            {"add2", {"add2", "relu5"}}
        });
    }

    SECTION("(Add#<*~.)+") {
        const auto results = SinglePassGraphMatching(g1).match("(Add#<*~.)+");

        checkMatches(results, {
            {"add", {"add", "conv4_conv", "relu3"}},
            {"add2", {"add2", "conv5_conv", "relu5"}}
        });
    }

    SECTION("Conv2D~*>(ReLU&Add)") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D~*>(ReLU&Add)");

        checkMatches(results, {
            {"conv5_conv", {"add2", "conv5_conv", "relu5"}}
        });
    }

    SECTION("Conv2D~>(ReLU&Add)") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D~>(ReLU&Add)");
        REQUIRE(results.empty());
    }

    SECTION("ReLU~*>((Pad2D->Conv2D-*>Add#)&Add#)") {
        const auto results = SinglePassGraphMatching(g1).match("ReLU~*>((Pad2D->Conv2D-*>Add#)&Add#)");

        checkMatches(results, {
            {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}}
        });
    }

    SECTION("ReLU-*>((Pad2D->Conv2D-*>Add)&Add)") {
        const auto results = SinglePassGraphMatching(g1).match("ReLU-*>((Pad2D->Conv2D-*>Add)&Add)");
        REQUIRE(results.empty());
    }

    SECTION("Pad2D->Conv2D[3x3]->ReLU") {
        auto gm = SinglePassGraphMatching(g1);
        // Predicate used as [3x3]: accepts convolutions with a 3x3 kernel.
        // The unchecked downcast assumes the matcher only evaluates [3x3] on
        // nodes that already matched the Conv2D type — TODO confirm.
        gm.addNodeLambda("3x3", [](const NodePtr& node) {
            const std::shared_ptr<Conv_Op<2>> op =
                std::static_pointer_cast<Conv_Op<2>>(node->getOperator());

            return (op->kernelDims() == std::array<DimSize_t, 2>({3, 3}));
        });

        const auto results = gm.match("Pad2D->Conv2D[3x3]->ReLU");

        checkMatches(results, {
            {"conv3_pad", {"conv3_conv", "conv3_pad", "relu3"}}
        });
    }

    SECTION(".[test]->Pad2D") {
        auto gm = SinglePassGraphMatching(g1);
        // Predicate used as [test]: any Add node, or the ReLU named "relu1".
        gm.addNodeLambda("test", [](const NodePtr& node) {
            return (node->type() == "Add" || (node->type() == "ReLU" && node->name() == "relu1"));
        });

        const auto results = gm.match(".[test]->Pad2D");

        checkMatches(results, {
            {"add", {"add", "conv5_pad"}},
            {"relu1", {"relu1", "conv2_pad"}}
        });
    }

    auto g2 = Sequential({
        Producer({16, 3, 512, 512}, "dataProvider"),
        Conv(3, 4, {5, 5}, "conv1"),
        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn1"),
        Conv(4, 4, {5, 5}, "conv2"),
        ReLU("relu2"),
        Conv(4, 4, {5, 5}, "conv3"),
        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
        FC(4, 4, false, "fc1"),
        FC(4, 4, false, "fc2"),
        FC(4, 4, false, "fc3"),
        ReLU("relu3"),
        Conv(1, 4, {5, 5}, "conv4")
    });

    SECTION("((Conv2D#->(.[exBN]|$))|(FC#->(.[exFC])*->$))") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exBN", [](const NodePtr& node) {
            return (node->type() != "BatchNorm2D");
        });
        gm.addNodeLambda("exFC", [](const NodePtr& node) {
            return (node->type() != "FC");
        });

        const auto results = gm.match("((Conv2D#->(.[exBN]|$))|(FC#->(.[exFC])*->$))");

        checkMatches(results, {
            {"conv2", {"conv2", "relu2"}},
            {"conv4", {"conv4"}},
            {"fc3", {"fc3", "relu3", "conv4"}}
        });
    }

    // Find last node of a type
    SECTION("FC#->(.[exFC])*->$") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exFC", [](const NodePtr& node) {
            return (node->type() != "FC");
        });

        const auto results = gm.match("FC#->(.[exFC])*->$");

        checkMatches(results, {
            {"fc3", {"fc3", "relu3", "conv4"}}
        });
    }

    SECTION("Conv2D#->(.[exConv])*->$") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exConv", [](const NodePtr& node) {
            return (node->type() != "Conv2D");
        });

        const auto results = gm.match("Conv2D#->(.[exConv])*->$");

        checkMatches(results, {
            {"conv4", {"conv4"}}
        });
    }

    // Find first node of a type
    SECTION("FC#<-(.[exFC])*<-$") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exFC", [](const NodePtr& node) {
            return (node->type() != "FC");
        });

        const auto results = gm.match("FC#<-(.[exFC])*<-$");

        checkMatches(results, {
            {"fc1", {"fc1", "bn3", "conv3", "relu2", "conv2", "bn1", "conv1", "dataProvider"}}
        });
    }

    SECTION("(((FC#|Conv2D#)<-(.[exParam])*<-$)|((FC#|Conv2D#)->(.[exParam])*->$));(FC#|Conv2D#)<1-Producer#") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exParam", [](const NodePtr& node) {
            return (node->type() != "FC" && node->type() != "Conv2D");
        });

        const auto results = gm.match("(((FC#|Conv2D#)<-(.[exParam])*<-$)|((FC#|Conv2D#)->(.[exParam])*->$));(FC#|Conv2D#)<1-Producer#");

        checkMatches(results, {
            {"conv1", {"conv1", "conv1_w", "dataProvider"}},
            {"conv4", {"conv4", "conv4_w"}}
        });
    }

    SECTION("Conv2D->ReLU [perf]") {
        constexpr std::size_t nbTests = 3;
        std::mt19937::result_type seed(1);

        // Set the console log level once, before the loop, rather than
        // re-applying it on every iteration.
        Log::setConsoleLevel(Log::Warn);

        for (std::size_t test = 0; test < nbTests; ++test) {
            RandomGraph randGraph;
            randGraph.types = {"Conv2D", "ReLU", "Dummy"};
            randGraph.typesWeights = {0.4, 0.4, 0.2};
            randGraph.avgIn = 1;
            randGraph.maxIn = 1;
            randGraph.maxOut = 1;
            randGraph.avgOut = 1;
            randGraph.density = 0.9;
            randGraph.acyclic = true;

            // Named randG so it does not shadow the g1 fixture built above.
            const auto randG = std::make_shared<GraphView>("g1");
            randG->add(randGraph.gen(seed, 100));
            randG->save("graph_single_pass");

            auto gm = SinglePassGraphMatching(randG);

            // steady_clock is monotonic, so the measured interval cannot be
            // skewed by wall-clock adjustments the way system_clock can.
            const auto start = std::chrono::steady_clock::now();
            const auto results = gm.match("Conv2D->ReLU#;ReLU#->Dummy");
            const auto end = std::chrono::steady_clock::now();
            const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);

            REQUIRE(!results.empty());
            ++seed;

            fmt::print("Found: {} - duration: {}\n", results.size(), duration);
        }
    }
}