Skip to content
Snippets Groups Projects
Commit 36e3660a authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Improved testing

parent 6fd88c4d
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!138Alternative graph matching
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReLU.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/operator/MetaOperatorDefs.hpp"
...@@ -47,188 +48,130 @@ TEST_CASE("[core/graph] Matching") { ...@@ -47,188 +48,130 @@ TEST_CASE("[core/graph] Matching") {
g1->save("Test_examples", true); g1->save("Test_examples", true);
SECTION("Conv->(ReLU->Pad->Conv)*") { SECTION("Conv->(ReLU->Pad->Conv)*") {
auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); const auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*");
REQUIRE(results.size() == 5); REQUIRE(results.size() == 5);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
} }
} }
SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") { SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") {
auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer");
REQUIRE(results.size() == 3); REQUIRE(results.size() == 3);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 5); REQUIRE(result.graph->getNodes().size() == 5);
} }
} }
SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") { SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") {
auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}");
REQUIRE(results.size() == 3); REQUIRE(results.size() == 3);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 5); REQUIRE(result.graph->getNodes().size() == 5);
} }
} }
SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") { SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") {
auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}"); const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}");
REQUIRE(results.size() == 3); REQUIRE(results.size() == 3);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 5); REQUIRE(result.graph->getNodes().size() == 5);
} }
} }
SECTION("Conv#->ReLU*;Conv#<-Pad*") { SECTION("Conv#->ReLU*;Conv#<-Pad*") {
auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*"); const auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*");
REQUIRE(results.size() == 5); REQUIRE(results.size() == 5);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3));
} }
} }
SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") {
auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); const auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?");
REQUIRE(results.size() == 5); REQUIRE(results.size() == 5);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
} }
} }
SECTION("Conv#->ReLU?;Conv#<-Pad?") { SECTION("Conv#->ReLU?;Conv#<-Pad?") {
auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); const auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?");
REQUIRE(results.size() == 5); REQUIRE(results.size() == 5);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3));
} }
} }
SECTION("(Conv|ReLU)->Add") { SECTION("(Conv|ReLU)->Add") {
auto results = GraphMatching(g1).match("(Conv|ReLU)->Add"); const auto results = GraphMatching(g1).match("(Conv|ReLU)->Add");
REQUIRE(results.size() == 2); REQUIRE(results.size() == 2);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 2); REQUIRE(result.graph->getNodes().size() == 2);
} }
} }
SECTION("Add<*-.") { SECTION("Add<*-.") {
auto results = GraphMatching(g1).match("Add<*-."); const auto results = GraphMatching(g1).match("Add<*-.");
REQUIRE(results.size() == 2); REQUIRE(results.size() == 2);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 2); REQUIRE(result.graph->getNodes().size() == 2);
} }
} }
SECTION("(Add#<*-.)+") { SECTION("(Add#<*-.)+") {
auto results = GraphMatching(g1).match("(Add#<*-.)+"); const auto results = GraphMatching(g1).match("(Add#<*-.)+");
REQUIRE(results.size() == 2); REQUIRE(results.size() == 2);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 3); REQUIRE(result.graph->getNodes().size() == 3);
} }
} }
SECTION("Conv-*>(ReLU&Add)") { SECTION("Conv-*>(ReLU&Add)") {
auto results = GraphMatching(g1).match("Conv-*>(ReLU&Add)"); const auto results = GraphMatching(g1).match("Conv-*>(ReLU&Add)");
REQUIRE(results.size() == 1); REQUIRE(results.size() == 1);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 3); REQUIRE(result.graph->getNodes().size() == 3);
} }
} }
SECTION("Conv->(ReLU&Add)") { SECTION("Conv->(ReLU&Add)") {
auto results = GraphMatching(g1).match("Conv->(ReLU&Add)"); const auto results = GraphMatching(g1).match("Conv->(ReLU&Add)");
REQUIRE(results.size() == 0); REQUIRE(results.size() == 0);
} }
SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") { SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") {
auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)"); const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)");
REQUIRE(results.size() == 1); REQUIRE(results.size() == 1);
for (auto result : results) { for (const auto& result : results) {
std::vector<std::string> nodesName; fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName));
std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(),
std::back_inserter(nodesName),
[](auto val){ return val->name(); });
fmt::print("Found: {}\n", nodesName);
REQUIRE(result.graph->getNodes().size() == 4); REQUIRE(result.graph->getNodes().size() == 4);
} }
} }
SECTION("ReLU-*>((Pad->Conv-*>Add)&Add)") { SECTION("ReLU-*>((Pad->Conv-*>Add)&Add)") {
auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)"); const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)");
REQUIRE(results.size() == 0); REQUIRE(results.size() == 0);
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment