From 693cb6ffa0901c2f7aded5346ed87fd77d02ffd3 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 7 Nov 2024 16:01:22 +0100
Subject: [PATCH] Working prototype

---
 unit_tests/recipies/Test_MatMulTiling.cpp | 94 +++++++++++++++++++++++
 1 file changed, 94 insertions(+)
 create mode 100644 unit_tests/recipies/Test_MatMulTiling.cpp

diff --git a/unit_tests/recipies/Test_MatMulTiling.cpp b/unit_tests/recipies/Test_MatMulTiling.cpp
new file mode 100644
index 00000000..4920dc63
--- /dev/null
+++ b/unit_tests/recipies/Test_MatMulTiling.cpp
@@ -0,0 +1,94 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cstddef>
+#include <random>  // std::random_device, std::mt19937, std::uniform_real_distribution
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "aidge/recipes/Recipes.hpp"
+#include "aidge/operator/MatMul.hpp"
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/operator/MaxPooling.hpp"
+#include "aidge/operator/GenericOperator.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/graph/OpArgs.hpp"
+#include "aidge/scheduler/SequentialScheduler.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[MatMulTiling]") {
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<float> valueDist(-1.0f, 1.0f);
+
+    auto dataProvider = Producer({2, 3, 80, 80}, "dataProvider");
+    auto w1 = Producer({2, 3, 80, 80}, "w1");
+    auto matmul1 = MatMul("matmul1");
+    auto w2 = Producer({2, 3, 80, 80}, "w1");
+    auto matmul2 = MatMul("matmul2");
+    auto w3 = Producer({2, 3, 80, 80}, "w1");
+    auto matmul3 = MatMul("matmul3");
+
+    dataProvider->addChild(matmul1, 0, 0);
+    w1->addChild(matmul1, 0, 1);
+    matmul1->addChild(matmul2, 0, 0);
+    w2->addChild(matmul2, 0, 1);
+    matmul2->addChild(matmul3, 0, 0);
+    w3->addChild(matmul3, 0, 1);
+
+    auto g1 = getConnectedGraphView(matmul1);
+    g1->setBackend("cpu");
+    g1->forwardDims();
+    g1->save("MatMulSplitting_graph");
+
+    // Fill random values
+    fmt::println("Fill random values");
+    auto tData = std::static_pointer_cast<OperatorTensor>(dataProvider->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tData->size(); ++i) {
+        tData->set<float>(i, valueDist(gen));
+    }
+    auto tw1 = std::static_pointer_cast<OperatorTensor>(w1->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tw1->size(); ++i) {
+        tw1->set<float>(i, valueDist(gen));
+    }
+    auto tw2 = std::static_pointer_cast<OperatorTensor>(w2->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tw2->size(); ++i) {
+        tw2->set<float>(i, valueDist(gen));
+    }
+    auto tw3 = std::static_pointer_cast<OperatorTensor>(w3->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tw3->size(); ++i) {
+        tw3->set<float>(i, valueDist(gen));
+    }
+
+    fmt::println("Schedule forward graph");
+    auto s1 = SequentialScheduler(g1);
+    s1.forward();
+
+    const auto tOut = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
+
+    // Tiling
+    fmt::println("Tiling");
+    matMulTiling(matmul1, {16, 16});
+
+    g1->setBackend("cpu");
+    g1->save("MatMulSplitting_graph_split");
+
+    // Check result
+    fmt::println("Schedule forward tiled graph");
+    s1 = SequentialScheduler(g1);
+    s1.resetScheduling();
+    s1.forward();
+
+    const auto tOutTiled = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
+    REQUIRE(approxEq<float>(tOut, tOutTiled));
+}
-- 
GitLab