diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index 014c7e716054b129070968fb3ebb065248b3db68..bc831c7b9238455f96800700292af90b58ac0fa1 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -11,6 +11,8 @@ #include <catch2/catch_test_macros.hpp> +#include <fmt/chrono.h> + #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Testing.hpp" @@ -219,4 +221,31 @@ TEST_CASE("[core/graph] Matching") { {"relu1", {"relu1", "conv2_pad"}} }); } + + SECTION("Conv->ReLU [perf]") { + const size_t nbTests = 10; + std::mt19937::result_type seed(1); + + for (int test = 0; test < nbTests; ++test) { + RandomGraph randGraph; + randGraph.types = {"Conv", "ReLU"}; + randGraph.typesWeights = {0.9, 0.1}; + const auto g1 = std::make_shared<GraphView>("g1"); + + Log::setConsoleLevel(Log::Warn); + g1->add(randGraph.gen(seed, 100)); + + auto gm = SinglePassGraphMatching(g1); + + const auto start = std::chrono::system_clock::now(); + const auto results = gm.match("Conv->ReLU"); + const auto end = std::chrono::system_clock::now(); + const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start); + + REQUIRE(results.size() > 0); + ++seed; + + fmt::print("Found: {} - duration: {}\n", results.size(), duration); + } + } }