From 693cb6ffa0901c2f7aded5346ed87fd77d02ffd3 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 7 Nov 2024 16:01:22 +0100
Subject: [PATCH] Working prototype

---
 unit_tests/recipies/Test_MatMulTiling.cpp | 94 +++++++++++++++++++++++
 1 file changed, 94 insertions(+)
 create mode 100644 unit_tests/recipies/Test_MatMulTiling.cpp

diff --git a/unit_tests/recipies/Test_MatMulTiling.cpp b/unit_tests/recipies/Test_MatMulTiling.cpp
new file mode 100644
index 00000000..4920dc63
--- /dev/null
+++ b/unit_tests/recipies/Test_MatMulTiling.cpp
@@ -0,0 +1,94 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cstddef>
+#include <random>  // std::random_device, std::mt19937, std::uniform_real_distribution
+
+#include <catch2/catch_test_macros.hpp>
+
+#include "aidge/recipes/Recipes.hpp"
+#include "aidge/operator/MatMul.hpp"
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/operator/MaxPooling.hpp"
+#include "aidge/operator/GenericOperator.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/graph/OpArgs.hpp"
+#include "aidge/scheduler/SequentialScheduler.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[MatMulTiling]") {
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::uniform_real_distribution<float> valueDist(-1.0f, 1.0f);
+
+    auto dataProvider = Producer({2, 3, 80, 80}, "dataProvider");
+    auto w1 = Producer({2, 3, 80, 80}, "w1");
+    auto matmul1 = MatMul("matmul1");
+    auto w2 = Producer({2, 3, 80, 80}, "w2");
+    auto matmul2 = MatMul("matmul2");
+    auto w3 = Producer({2, 3, 80, 80}, "w3");
+    auto matmul3 = MatMul("matmul3");
+
+    dataProvider->addChild(matmul1, 0, 0);
+    w1->addChild(matmul1, 0, 1);
+    matmul1->addChild(matmul2, 0, 0);
+    w2->addChild(matmul2, 0, 1);
+    matmul2->addChild(matmul3, 0, 0);
+    w3->addChild(matmul3, 0, 1);
+
+    auto g1 = getConnectedGraphView(matmul1);
+    g1->setBackend("cpu");
+    g1->forwardDims();
+    g1->save("MatMulSplitting_graph");
+
+    // Fill random values
+    fmt::println("Fill random values");
+    auto tData = std::static_pointer_cast<OperatorTensor>(dataProvider->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tData->size(); ++i) {
+        tData->set<float>(i, valueDist(gen));
+    }
+    auto tw1 = std::static_pointer_cast<OperatorTensor>(w1->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tw1->size(); ++i) {
+        tw1->set<float>(i, valueDist(gen));
+    }
+    auto tw2 = std::static_pointer_cast<OperatorTensor>(w2->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tw2->size(); ++i) {
+        tw2->set<float>(i, valueDist(gen));
+    }
+    auto tw3 = std::static_pointer_cast<OperatorTensor>(w3->getOperator())->getOutput(0);
+    for (size_t i = 0; i < tw3->size(); ++i) {
+        tw3->set<float>(i, valueDist(gen));
+    }
+
+    fmt::println("Schedule forward graph");
+    auto s1 = SequentialScheduler(g1);
+    s1.forward();
+
+    const auto tOut = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
+
+    // Tiling
+    fmt::println("Tiling");
+    matMulTiling(matmul1, {16, 16});
+
+    g1->setBackend("cpu");
+    g1->save("MatMulSplitting_graph_split");
+
+    // Check result
fmt::println("Schedule forward tiled graph"); + s1 = SequentialScheduler(g1); + s1.resetScheduling(); + s1.forward(); + + const auto tOutTiled = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone(); + REQUIRE(approxEq<float>(tOut, tOutTiled)); +} -- GitLab