Skip to content
Snippets Groups Projects
Commit 693cb6ff authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working prototype

parent 2ed4c450
No related branches found
No related tags found
3 merge requests!118v0.4.0,!108v0.4.0,!105Add MatMulTiling recipe
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef>
#include <random> // std::random_device, std::mt19937, std::uniform_real_distribution
#include <catch2/catch_test_macros.hpp>
#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/utils/TensorUtils.hpp"
using namespace Aidge;
TEST_CASE("[MatMulTiling]") {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> valueDist(-1.0f, 1.0f);
auto dataProvider = Producer({2, 3, 80, 80}, "dataProvider");
auto w1 = Producer({2, 3, 80, 80}, "w1");
auto matmul1 = MatMul("matmul1");
auto w2 = Producer({2, 3, 80, 80}, "w1");
auto matmul2 = MatMul("matmul2");
auto w3 = Producer({2, 3, 80, 80}, "w1");
auto matmul3 = MatMul("matmul3");
dataProvider->addChild(matmul1, 0, 0);
w1->addChild(matmul1, 0, 1);
matmul1->addChild(matmul2, 0, 0);
w2->addChild(matmul2, 0, 1);
matmul2->addChild(matmul3, 0, 0);
w3->addChild(matmul3, 0, 1);
auto g1 = getConnectedGraphView(matmul1);
g1->setBackend("cpu");
g1->forwardDims();
g1->save("MatMulSplitting_graph");
// Fill random values
fmt::println("Fill random values");
auto tData = std::static_pointer_cast<OperatorTensor>(dataProvider->getOperator())->getOutput(0);
for (size_t i = 0; i < tData->size(); ++i) {
tData->set<float>(i, valueDist(gen));
}
auto tw1 = std::static_pointer_cast<OperatorTensor>(w1->getOperator())->getOutput(0);
for (size_t i = 0; i < tw1->size(); ++i) {
tw1->set<float>(i, valueDist(gen));
}
auto tw2 = std::static_pointer_cast<OperatorTensor>(w2->getOperator())->getOutput(0);
for (size_t i = 0; i < tw2->size(); ++i) {
tw2->set<float>(i, valueDist(gen));
}
auto tw3 = std::static_pointer_cast<OperatorTensor>(w3->getOperator())->getOutput(0);
for (size_t i = 0; i < tw3->size(); ++i) {
tw3->set<float>(i, valueDist(gen));
}
fmt::println("Schedule forward graph");
auto s1 = SequentialScheduler(g1);
s1.forward();
const auto tOut = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
// Tiling
fmt::println("Tiling");
matMulTiling(matmul1, {16, 16});
g1->setBackend("cpu");
g1->save("MatMulSplitting_graph_split");
// Check result
fmt::println("Schedule forward tiled graph");
s1 = SequentialScheduler(g1);
s1.resetScheduling();
s1.forward();
const auto tOutTiled = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
REQUIRE(approxEq<float>(tOut, tOutTiled));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment