Skip to content
Snippets Groups Projects
Commit ecb77eed authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'matmultiling' into 'dev'

Add MatMulTiling recipe

See merge request !244
parents bce5965b 39dbad54
No related branches found
No related tags found
No related merge requests found
...@@ -119,30 +119,17 @@ private: ...@@ -119,30 +119,17 @@ private:
template <typename T> template <typename T>
const std::string TensorImpl_cpu<T>::Backend = "cpu"; const std::string TensorImpl_cpu<T>::Backend = "cpu";
namespace { REGISTRAR(Tensor, {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Float64( REGISTRAR(Tensor, {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create);
{"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create); REGISTRAR(Tensor, {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Float32( REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
{"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create); REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Float16( REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
{"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create); REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int64( REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
{"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int32( REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
{"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create); REGISTRAR(Tensor, {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int16(
{"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int8(
{"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_UInt64(
{"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_UInt32(
{"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_UInt16(
{"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_UInt8(
{"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create);
} // namespace
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CPU_DATA_TENSORIMPL_H_ */ #endif /* AIDGE_CPU_DATA_TENSORIMPL_H_ */
...@@ -134,6 +134,23 @@ void explicitTranspose(std::shared_ptr<GraphView> graphView); ...@@ -134,6 +134,23 @@ void explicitTranspose(std::shared_ptr<GraphView> graphView);
*/ */
void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false);
/**
* @brief Tile any :cpp:function:`Aidge::MatMul` operator to several fixed size matrix multiplications.
* For instance, for a MatMul of size 80x80 and a tiling of 16x16, this will tile
* the MatMul operator to 25 (5 by 5) MatMul operators of size 16x16, with Slice
* operators inserted at the inputs and Concat operators inserted at the outputs.
*
* This is especially useful when matrix multiplication must be mapped to fixed
* maximum size hardware TPU (Tensor Processing Unit) or MMA (Matrix Multiplication
* Accelerator). This recipe can be combined with the :cpp:function:`Aidge::convToMatMul` recipe in
* order to convert convolutions to matrix multiplication beforehand, and
* :cpp:function:`Aidge::constantFolding` recipe to fold sliced constant tensors.
*
* @param matMul MatMul operator to be tiled.
* @param maxDims Maximum output dimensions of the tiled MatMul operators.
*/
void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims);
/** /**
* Fuse each sub-graph matching a query in a Meta Operator. * Fuse each sub-graph matching a query in a Meta Operator.
* @param graph Graph to manipulate * @param graph Graph to manipulate
......
...@@ -49,7 +49,9 @@ std::shared_ptr<Aidge::Operator> Aidge::Concat_Op::clone() const { ...@@ -49,7 +49,9 @@ std::shared_ptr<Aidge::Operator> Aidge::Concat_Op::clone() const {
void Aidge::Concat_OpImpl::forward() { void Aidge::Concat_OpImpl::forward() {
const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp); const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp);
const DimSize_t axis = op.axis(); auto axis = op.axis();
const auto nbDimsInput0 = op.getInput(0)->nbDims();
axis = (axis < 0) ? axis + static_cast<std::int32_t>(nbDimsInput0) : axis;
assert(op.getInput(0) && "missing input in Concat operator"); assert(op.getInput(0) && "missing input in Concat operator");
for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) { for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
......
...@@ -71,7 +71,7 @@ bool Aidge::MatMul_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -71,7 +71,7 @@ bool Aidge::MatMul_Op::forwardDims(bool /*allowDataDependency*/) {
std::vector<std::size_t> outDims = std::vector<std::size_t>(dims_size-2, 1); std::vector<std::size_t> outDims = std::vector<std::size_t>(dims_size-2, 1);
for (std::size_t i = 0; i < dims_size-2; ++i) { for (std::size_t i = 0; i < dims_size-2; ++i) {
AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad vector dimension."); AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad dimension {}: {} != {} for input #0 {} and #1 {}.", i, dims0[i], dims1[i], dims0, dims1);
outDims[i] = std::max(dims0[i], dims1[i]); outDims[i] = std::max(dims0[i], dims1[i]);
} }
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
// see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) {
if (matMul->getOperator()->type() != "MatMul") {
AIDGE_INTERNAL_ASSERT("Operator should be a MatMul.");
}
AIDGE_ASSERT(matMul->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& op = std::static_pointer_cast<OperatorTensor>(matMul->getOperator());
if (!op->dimsForwarded()) {
AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling");
}
const auto& in0Tensor = op->getInput(0);
const auto& in1Tensor = op->getInput(1);
const auto& outTensor = op->getOutput(0);
const auto& input0Dims = in0Tensor->dims();
const auto& input1Dims = in1Tensor->dims();
const auto& outputDims = outTensor->dims();
const auto& outputMatDims = std::vector<std::size_t>(outputDims.end() - 2, outputDims.end());;
if (outputMatDims[0] > maxDims[0] || outputMatDims[1] > maxDims[1]) {
const auto sliceDims = (outputMatDims[0] > maxDims[0]) ? input0Dims : input1Dims;
std::int32_t axis;
std::int64_t splitIndex0_end = static_cast<std::int64_t>(sliceDims.end()[-2]);
std::int64_t splitIndex0_start = 0;
std::int64_t splitIndex1_end = static_cast<std::int64_t>(sliceDims.end()[-1]);
std::int64_t splitIndex1_start = 0;
if (outputMatDims[0] > maxDims[0]) {
splitIndex0_end = maxDims[0];
splitIndex0_start = maxDims[0];
axis = -2;
}
else {
splitIndex1_end = maxDims[1];
splitIndex1_start = maxDims[1];
axis = -1;
}
auto identity0 = Identity();
auto sliceX0 = Slice();
auto sliceX0_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{0, 0}}), "", true);
sliceX0_starts->addChild(sliceX0, 0, 1);
auto sliceX0_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_end, splitIndex1_end}}), "", true);
sliceX0_ends->addChild(sliceX0, 0, 2);
auto sliceX0_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true);
sliceX0_axes->addChild(sliceX0, 0, 3);
auto sliceX0_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true);
sliceX0_steps->addChild(sliceX0, 0, 4);
auto matMulX0 = MatMul();
auto identity1 = Identity();
auto sliceX1 = Slice();
auto sliceX1_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_start, splitIndex1_start}}), "", true);
sliceX1_starts->addChild(sliceX1, 0, 1);
auto sliceX1_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{static_cast<std::int64_t>(sliceDims.end()[-2]), static_cast<std::int64_t>(sliceDims.end()[-1])}}), "", true);
sliceX1_ends->addChild(sliceX1, 0, 2);
auto sliceX1_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true);
sliceX1_axes->addChild(sliceX1, 0, 3);
auto sliceX1_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true);
sliceX1_steps->addChild(sliceX1, 0, 4);
auto matMulX1 = MatMul();
auto concat = Concat(2, axis);
if (outputMatDims[0] > maxDims[0]) {
identity0->addChild(sliceX0, 0, 0);
identity0->addChild(sliceX1, 0, 0);
identity1->addChild(matMulX0, 0, 1);
identity1->addChild(matMulX1, 0, 1);
sliceX0->addChild(matMulX0, 0, 0);
sliceX1->addChild(matMulX1, 0, 0);
}
else {
identity0->addChild(matMulX0, 0, 0);
identity0->addChild(matMulX1, 0, 0);
identity1->addChild(sliceX0, 0, 0);
identity1->addChild(sliceX1, 0, 0);
sliceX0->addChild(matMulX0, 0, 1);
sliceX1->addChild(matMulX1, 0, 1);
}
matMulX0->addChild(concat, 0, 0);
matMulX1->addChild(concat, 0, 1);
auto gMatMul = std::make_shared<GraphView>();
gMatMul->add({matMul});
auto g = std::make_shared<GraphView>();
g->add({identity0});
g->add({identity1});
g->add({sliceX0, sliceX0_starts, sliceX0_ends, sliceX0_axes, sliceX0_steps, matMulX0, matMulX1, sliceX1, sliceX1_starts, sliceX1_ends, sliceX1_axes, sliceX1_steps, concat});
auto replaced = GraphView::replace(gMatMul, g);
if (replaced) {
g->forwardDims({}, true);
// Recursive tiling
matMulTiling(matMulX1, maxDims);
matMulTiling(matMulX0, maxDims);
}
else {
Log::warn("Unable to split MatMul {}", matMul->name());
}
}
}
...@@ -13,24 +13,15 @@ ...@@ -13,24 +13,15 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
size_t Aidge::removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers) { size_t Aidge::removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); auto matches = SinglePassGraphMatching(graphView).match(type);
regex->setNodeKey(type, "getType($) =='" + type + "'"); for (const auto& match : matches) {
regex->addQuery(type + "#"); std::set<NodePtr> nodesToRemove = {match.graph->rootNode()};
const auto matches = regex->match(graphView);
for (const auto& solution : matches) {
assert(solution->at(type).size() == 1 && "Wrong number of nodes to replace\n");
std::set<NodePtr> nodesToRemove = solution->at(type);
if (incProducers) { if (incProducers) {
for (const auto& nodePtr: (*solution->at(type).begin())->getParents()) { for (const auto& nodePtr: match.graph->rootNode()->getParents()) {
if (nodePtr != nullptr && nodePtr->type() == "Producer") { if (nodePtr != nullptr && nodePtr->type() == "Producer") {
nodesToRemove.insert(nodePtr); nodesToRemove.insert(nodePtr);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment