diff --git a/unit_tests/recipies/Test_MatMulTiling.cpp b/unit_tests/recipies/Test_MatMulTiling.cpp index 4920dc635929fb4f5f31ccbaae43fc5a589c9a87..46d5418fd557fbb716f7e1d9c54eb76d94b0061e 100644 --- a/unit_tests/recipies/Test_MatMulTiling.cpp +++ b/unit_tests/recipies/Test_MatMulTiling.cpp @@ -22,6 +22,7 @@ #include "aidge/operator/Producer.hpp" #include "aidge/graph/OpArgs.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/graph/Matching.hpp" #include "aidge/utils/TensorUtils.hpp" using namespace Aidge; @@ -79,10 +80,22 @@ TEST_CASE("[MatMulTiling]") { // Tiling fmt::println("Tiling"); matMulTiling(matmul1, {16, 16}); + removeIdentity(g1); g1->setBackend("cpu"); g1->save("MatMulSplitting_graph_split"); + auto gm = SinglePassGraphMatching(g1); + gm.addNodeLambda("16x16", [](const NodePtr& node) { + const auto op = + std::static_pointer_cast<OperatorTensor>(node->getOperator()); + const auto dims = op->getOutput(0)->dims(); + return (dims.end()[-2] == 16 && dims.end()[-1] == 16); + }); + + const auto results = gm.match("MatMul[16x16]"); + REQUIRE(results.size() == 25); + // Check result fmt::println("Schedule forward tiled graph"); s1 = SequentialScheduler(g1);