From e87dd5bdeda6164b6f908818b7d69b39c436638d Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 8 Nov 2024 10:15:06 +0100 Subject: [PATCH] Working 2D tiling --- unit_tests/recipies/Test_MatMulTiling.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/unit_tests/recipies/Test_MatMulTiling.cpp b/unit_tests/recipies/Test_MatMulTiling.cpp index 4920dc63..46d5418f 100644 --- a/unit_tests/recipies/Test_MatMulTiling.cpp +++ b/unit_tests/recipies/Test_MatMulTiling.cpp @@ -22,6 +22,7 @@ #include "aidge/operator/Producer.hpp" #include "aidge/graph/OpArgs.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/graph/Matching.hpp" #include "aidge/utils/TensorUtils.hpp" using namespace Aidge; @@ -79,10 +80,22 @@ TEST_CASE("[MatMulTiling]") { // Tiling fmt::println("Tiling"); matMulTiling(matmul1, {16, 16}); + removeIdentity(g1); g1->setBackend("cpu"); g1->save("MatMulSplitting_graph_split"); + auto gm = SinglePassGraphMatching(g1); + gm.addNodeLambda("16x16", [](const NodePtr& node) { + const auto op = + std::static_pointer_cast<OperatorTensor>(node->getOperator()); + const auto dims = op->getOutput(0)->dims(); + return (dims.end()[-2] == 16 && dims.end()[-1] == 16); + }); + + const auto results = gm.match("MatMul[16x16]"); + REQUIRE(results.size() == 25); + // Check result fmt::println("Schedule forward tiled graph"); s1 = SequentialScheduler(g1); -- GitLab