Commit a2bcd467 authored by Olivier BICHLER

Merge branch 'matmultiling' into 'dev'

Add MatMulTiling recipe

See merge request eclipse/aidge/aidge_backend_cpu!105
parents d318dbfe e87dd5bd
3 merge requests: !118 v0.4.0, !108 v0.4.0, !105 Add MatMulTiling recipe
Pipeline #58696 passed
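
For context: the matMulTiling recipe splits a single MatMul whose output exceeds the requested tile dimensions into a grid of smaller MatMul nodes, which the new unit test in this commit verifies both structurally (node count) and numerically. A minimal usage sketch, condensed from that test (the producer names and the GraphView.hpp include are illustrative assumptions; the matMulTiling, removeIdentity, and getConnectedGraphView calls match the test below):

#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"

void matMulTilingSketch() {
    using namespace Aidge;

    // A single MatMul fed by two producers (illustrative shapes).
    auto x  = Producer({2, 3, 80, 80}, "x");
    auto w  = Producer({2, 3, 80, 80}, "w");
    auto mm = MatMul("mm");
    x->addChild(mm, 0, 0);  // data input
    w->addChild(mm, 0, 1);  // weight input

    auto g = getConnectedGraphView(mm);
    g->setBackend("cpu");
    g->forwardDims();            // the recipe needs concrete dimensions
    matMulTiling(mm, {16, 16});  // 80x80 output -> 5x5 grid of 16x16 tiles
    removeIdentity(g);           // drop Identity nodes left by the recipe
}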
@@ -89,13 +89,13 @@ void SliceImpl_cpu_forward_kernel(const std::vector<std::int64_t>& starts,
 }
 REGISTRAR(SliceImpl_cpu,
-    {DataType::Float32},
+    {{DataType::Float32, DataType::Any}, {DataType::Float32}},
     {ProdConso::inPlaceModel, Aidge::SliceImpl_cpu_forward_kernel<float, float>, nullptr});
 REGISTRAR(SliceImpl_cpu,
-    {DataType::Float64},
+    {{DataType::Float64, DataType::Any}, {DataType::Float64}},
     {ProdConso::inPlaceModel, Aidge::SliceImpl_cpu_forward_kernel<double, double>, nullptr});
 REGISTRAR(SliceImpl_cpu,
-    {DataType::Int32},
+    {{DataType::Int32, DataType::Any}, {DataType::Int32}},
     {ProdConso::inPlaceModel, Aidge::SliceImpl_cpu_forward_kernel<int32_t, int32_t>, nullptr});
 } // namespace Aidge
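
The hunk above widens the SliceImpl registration key from a single data type to an explicit {inputs}, {outputs} type specification: the data input keeps its concrete type, a DataType::Any entry presumably covers Slice's index inputs, and the output type is unchanged. The second file in the commit is the new unit test for the MatMulTiling recipe: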
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef>  // std::size_t
#include <memory>   // std::static_pointer_cast
#include <random>   // std::random_device, std::mt19937, std::uniform_real_distribution

#include <catch2/catch_test_macros.hpp>
#include <fmt/core.h>  // fmt::println

#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/utils/TensorUtils.hpp"
using namespace Aidge;

TEST_CASE("[MatMulTiling]") {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> valueDist(-1.0f, 1.0f);

    // Chain of three MatMul nodes, each with its own weight producer.
    auto dataProvider = Producer({2, 3, 80, 80}, "dataProvider");
    auto w1 = Producer({2, 3, 80, 80}, "w1");
    auto matmul1 = MatMul("matmul1");
    auto w2 = Producer({2, 3, 80, 80}, "w2");
    auto matmul2 = MatMul("matmul2");
    auto w3 = Producer({2, 3, 80, 80}, "w3");
    auto matmul3 = MatMul("matmul3");
    dataProvider->addChild(matmul1, 0, 0);
    w1->addChild(matmul1, 0, 1);
    matmul1->addChild(matmul2, 0, 0);
    w2->addChild(matmul2, 0, 1);
    matmul2->addChild(matmul3, 0, 0);
    w3->addChild(matmul3, 0, 1);

    auto g1 = getConnectedGraphView(matmul1);
    g1->setBackend("cpu");
    g1->forwardDims();
    g1->save("MatMulSplitting_graph");
    // Fill random values
    fmt::println("Fill random values");
    auto tData = std::static_pointer_cast<OperatorTensor>(dataProvider->getOperator())->getOutput(0);
    for (std::size_t i = 0; i < tData->size(); ++i) {
        tData->set<float>(i, valueDist(gen));
    }
    auto tw1 = std::static_pointer_cast<OperatorTensor>(w1->getOperator())->getOutput(0);
    for (std::size_t i = 0; i < tw1->size(); ++i) {
        tw1->set<float>(i, valueDist(gen));
    }
    auto tw2 = std::static_pointer_cast<OperatorTensor>(w2->getOperator())->getOutput(0);
    for (std::size_t i = 0; i < tw2->size(); ++i) {
        tw2->set<float>(i, valueDist(gen));
    }
    auto tw3 = std::static_pointer_cast<OperatorTensor>(w3->getOperator())->getOutput(0);
    for (std::size_t i = 0; i < tw3->size(); ++i) {
        tw3->set<float>(i, valueDist(gen));
    }
fmt::println("Schedule forward graph");
auto s1 = SequentialScheduler(g1);
s1.forward();
const auto tOut = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
// Tiling
fmt::println("Tiling");
matMulTiling(matmul1, {16, 16});
removeIdentity(g1);
g1->setBackend("cpu");
g1->save("MatMulSplitting_graph_split");
auto gm = SinglePassGraphMatching(g1);
gm.addNodeLambda("16x16", [](const NodePtr& node) {
const auto op =
std::static_pointer_cast<OperatorTensor>(node->getOperator());
const auto dims = op->getOutput(0)->dims();
return (dims.end()[-2] == 16 && dims.end()[-1] == 16);
});
const auto results = gm.match("MatMul[16x16]");
REQUIRE(results.size() == 25);
// Check result
fmt::println("Schedule forward tiled graph");
s1 = SequentialScheduler(g1);
s1.resetScheduling();
s1.forward();
const auto tOutTiled = std::static_pointer_cast<OperatorTensor>(g1->getOrderedOutputs()[0].first->getOperator())->getOutput(0)->clone();
REQUIRE(approxEq<float>(tOut, tOutTiled));
}
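
As a follow-up illustration (not part of this merge request), the same recipe could be applied to every MatMul in a graph. This is a hypothetical sketch: it assumes matMulTiling's node-plus-tile-dims signature as called above, plus the GraphView::getNodes() and Node::type() accessors; tileAllMatMuls and its tile size are made up for the example:

#include <memory>
#include <vector>

#include "aidge/graph/GraphView.hpp"
#include "aidge/recipes/Recipes.hpp"

// Hypothetical helper: tile every MatMul node in a graph.
void tileAllMatMuls(std::shared_ptr<Aidge::GraphView> graph,
                    const std::vector<Aidge::DimSize_t>& tileDims) {
    // Snapshot the node set first: matMulTiling rewires the graph.
    const auto nodes = graph->getNodes();
    for (const auto& node : nodes) {
        if (node->type() == "MatMul") {
            Aidge::matMulTiling(node, tileDims);
        }
    }
    Aidge::removeIdentity(graph);  // clean up, as the test above does
}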