diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index a013eace8350223315e3e5a468b027701c9886c3..10d2da45a2f46df45f5defd6b3c16ce4f46c3c17 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -13,6 +13,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Testing.hpp" #include "aidge/graph/OpArgs.hpp" #include "aidge/operator/ReLU.hpp" #include "aidge/operator/MetaOperatorDefs.hpp" @@ -47,188 +48,130 @@ TEST_CASE("[core/graph] Matching") { g1->save("Test_examples", true); SECTION("Conv->(ReLU->Pad->Conv)*") { - auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); + const auto results = GraphMatching(g1).match("Conv->(ReLU->Pad->Conv)*"); REQUIRE(results.size() == 5); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); } } SECTION("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer") { - auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); + const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;Conv#<1-Producer;Conv#<2-Producer"); REQUIRE(results.size() == 3); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 5); } } SECTION("Pad->Conv#->ReLU;(Conv#<*-Producer){2}") { - auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); + const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-Producer){2}"); REQUIRE(results.size() == 3); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 5); } } SECTION("Pad->Conv#->ReLU;(Conv#<*-.){2}") { - auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}"); + const auto results = GraphMatching(g1).match("Pad->Conv#->ReLU;(Conv#<*-.){2}"); REQUIRE(results.size() == 3); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 5); } } SECTION("Conv#->ReLU*;Conv#<-Pad*") { - auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*"); + const auto results = GraphMatching(g1).match("Conv#->ReLU*;Conv#<-Pad*"); REQUIRE(results.size() == 5); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); } } SECTION("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?") { - auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); + const auto results = GraphMatching(g1).match("Conv#->ReLU?-*>Add#1?->ReLU?;Conv#<-Pad?;(Add#1<*-.)?"); REQUIRE(results.size() == 5); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); } } SECTION("Conv#->ReLU?;Conv#<-Pad?") { - auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); + const auto results = GraphMatching(g1).match("Conv#->ReLU?;Conv#<-Pad?"); REQUIRE(results.size() == 5); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE((result.graph->getNodes().size() == 2 || result.graph->getNodes().size() == 3)); } } SECTION("(Conv|ReLU)->Add") { - auto results = GraphMatching(g1).match("(Conv|ReLU)->Add"); + const auto results = GraphMatching(g1).match("(Conv|ReLU)->Add"); REQUIRE(results.size() == 2); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 2); } } SECTION("Add<*-.") { - auto results = GraphMatching(g1).match("Add<*-."); + const auto results = GraphMatching(g1).match("Add<*-."); REQUIRE(results.size() == 2); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 2); } } SECTION("(Add#<*-.)+") { - auto results = GraphMatching(g1).match("(Add#<*-.)+"); + const auto results = GraphMatching(g1).match("(Add#<*-.)+"); REQUIRE(results.size() == 2); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 3); } } SECTION("Conv-*>(ReLU&Add)") { - auto results = GraphMatching(g1).match("Conv-*>(ReLU&Add)"); + const auto results = GraphMatching(g1).match("Conv-*>(ReLU&Add)"); REQUIRE(results.size() == 1); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 3); } } SECTION("Conv->(ReLU&Add)") { - auto results = GraphMatching(g1).match("Conv->(ReLU&Add)"); + const auto results = GraphMatching(g1).match("Conv->(ReLU&Add)"); REQUIRE(results.size() == 0); } SECTION("ReLU-*>((Pad->Conv-*>Add#)&Add#)") { - auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)"); + const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add#)&Add#)"); REQUIRE(results.size() == 1); - for (auto result : results) { - std::vector<std::string> nodesName; - std::transform(result.graph->getNodes().begin(), result.graph->getNodes().end(), - std::back_inserter(nodesName), - [](auto val){ return val->name(); }); - fmt::print("Found: {}\n", nodesName); - + for (const auto& result : results) { + fmt::print("Found: {}\n", nodePtrTo(result.graph->getNodes(), nodePtrToName)); REQUIRE(result.graph->getNodes().size() == 4); } } SECTION("ReLU-*>((Pad->Conv-*>Add)&Add)") { - auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)"); + const auto results = GraphMatching(g1).match("ReLU-*>((Pad->Conv-*>Add)&Add)"); REQUIRE(results.size() == 0); } }