Commit a2bcd467 authored by Olivier BICHLER

Merge branch 'matmultiling' into 'dev'

Add MatMulTiling recipe

See merge request !105
parents d318dbfe e87dd5bd
3 merge requests: !118 v0.4.0, !108 v0.4.0, !105 Add MatMulTiling recipe
Pipeline #58696 passed
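The MatMulTiling recipe splits a large MatMul operator into a grid of smaller tiled MatMul operators. Below is a minimal usage sketch condensed from the unit test added in this merge request; the headers and helper calls (Producer, MatMul, getConnectedGraphView, matMulTiling, removeIdentity) are taken from that test, so treat this as an illustration rather than reference API documentation:

#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"

using namespace Aidge;

// Build a single MatMul with 80x80 operands, fed by two Producers.
auto x = Producer({2, 3, 80, 80}, "x");
auto w = Producer({2, 3, 80, 80}, "w");
auto matmul = MatMul("matmul");
x->addChild(matmul, 0, 0);
w->addChild(matmul, 0, 1);

auto g = getConnectedGraphView(matmul);
g->setBackend("cpu");
g->forwardDims();

// Split the 80x80 output into 16x16 tiles; the test follows the recipe with
// removeIdentity(), suggesting it leaves helper Identity nodes behind.
matMulTiling(matmul, {16, 16});
removeIdentity(g);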
@@ -89,13 +89,13 @@ void SliceImpl_cpu_forward_kernel(const std::vector<std::int64_t>& starts,
}
REGISTRAR(SliceImpl_cpu,
-    {DataType::Float32},
+    {{DataType::Float32, DataType::Any}, {DataType::Float32}},
    {ProdConso::inPlaceModel, Aidge::SliceImpl_cpu_forward_kernel<float, float>, nullptr});
REGISTRAR(SliceImpl_cpu,
-    {DataType::Float64},
+    {{DataType::Float64, DataType::Any}, {DataType::Float64}},
    {ProdConso::inPlaceModel, Aidge::SliceImpl_cpu_forward_kernel<double, double>, nullptr});
REGISTRAR(SliceImpl_cpu,
-    {DataType::Int32},
+    {{DataType::Int32, DataType::Any}, {DataType::Int32}},
    {ProdConso::inPlaceModel, Aidge::SliceImpl_cpu_forward_kernel<int32_t, int32_t>, nullptr});
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef>
#include <random> // std::random_device, std::mt19937, std::uniform_real_distribution
#include <catch2/catch_test_macros.hpp>
#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/utils/TensorUtils.hpp"
using namespace Aidge;
TEST_CASE("[MatMulTiling]") {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> valueDist(-1.0f, 1.0f);
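    // Build a chain of three MatMul nodes, each fed by a dedicated weight Producer.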
    auto dataProvider = Producer({2, 3, 80, 80}, "dataProvider");
    auto w1 = Producer({2, 3, 80, 80}, "w1");
    auto matmul1 = MatMul("matmul1");
    auto w2 = Producer({2, 3, 80, 80}, "w2");
    auto matmul2 = MatMul("matmul2");
    auto w3 = Producer({2, 3, 80, 80}, "w3");
    auto matmul3 = MatMul("matmul3");
    dataProvider->addChild(matmul1, 0, 0);
    w1->addChild(matmul1, 0, 1);
    matmul1->addChild(matmul2, 0, 0);
    w2->addChild(matmul2, 0, 1);
    matmul2->addChild(matmul3, 0, 0);
    w3->addChild(matmul3, 0, 1);

    auto g1 = getConnectedGraphView(matmul1);
    g1->setBackend("cpu");
    g1->forwardDims();
    g1->save("MatMulSplitting_graph");
    // Fill random values
    fmt::println("Fill random values");
    auto tData = std::static_pointer_cast<OperatorTensor>(dataProvider->getOperator())->getOutput(0);
    for (size_t i = 0; i < tData->size(); ++i) {
        tData->set<float>(i, valueDist(gen));
    }
    auto tw1 = std::static_pointer_cast<OperatorTensor>(w1->getOperator())->getOutput(0);
    for (size_t i = 0; i < tw1->size(); ++i) {
        tw1->set<float>(i, valueDist(gen));
    }
    auto tw2 = std::static_pointer_cast<OperatorTensor>(w2->getOperator())->getOutput(0);
    for (size_t i = 0; i < tw2->size(); ++i) {
        tw2->set<float>(i, valueDist(gen));
    }
    auto tw3 = std::static_pointer_cast<OperatorTensor>(w3->getOperator())->getOutput(0);
    for (size_t i = 0; i < tw3->size(); ++i) {
        tw3->set<float>(i, valueDist(gen));
    }

    fmt::println("Schedule forward graph");
    auto s1 = SequentialScheduler(g1);
    s1.forward();

    const auto tOut = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
    // Tiling
    fmt::println("Tiling");
    matMulTiling(matmul1, {16, 16});
    removeIdentity(g1);

    g1->setBackend("cpu");
    g1->save("MatMulSplitting_graph_split");

    auto gm = SinglePassGraphMatching(g1);
    gm.addNodeLambda("16x16", [](const NodePtr& node) {
        const auto op =
            std::static_pointer_cast<OperatorTensor>(node->getOperator());
        const auto dims = op->getOutput(0)->dims();
        return (dims.end()[-2] == 16 && dims.end()[-1] == 16);
    });

    const auto results = gm.match("MatMul[16x16]");
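    // matmul1 has an 80x80 output: tiling it into 16x16 blocks gives a 5x5 grid,
    // so 25 MatMul nodes with 16x16 outputs are expected after the recipe.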
    REQUIRE(results.size() == 25);

    // Check result
    fmt::println("Schedule forward tiled graph");
    s1 = SequentialScheduler(g1);
    s1.resetScheduling();
    s1.forward();

    const auto tOutTiled = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
    REQUIRE(approxEq<float>(tOut, tOutTiled));
}