-
Olivier BICHLER authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
MatMulTiling.cpp 4.54 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
// see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
/**
 * @brief Recursively split a MatMul node until the matrix part (last two
 * dimensions) of every resulting MatMul output fits within @p maxDims.
 *
 * When the output has too many rows, input #0 is sliced in two along its row
 * axis; otherwise, when it has too many columns, input #1 is sliced in two
 * along its column axis. Each half feeds its own MatMul and the two partial
 * results are concatenated back along the split axis. The original node is
 * substituted by this micro-graph through GraphView::replace(), and the two new
 * MatMul nodes are tiled again recursively.
 *
 * @param matMul  Node holding the MatMul operator to tile. Its dimensions must
 *                already have been forwarded and both inputs must have at
 *                least 2 dimensions.
 * @param maxDims Maximum allowed {rows, columns} of a MatMul output. Both
 *                entries must be strictly positive.
 */
void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) {
    // Precondition checks. The original code passed a string literal to
    // AIDGE_INTERNAL_ASSERT, which is always true and therefore never fired.
    AIDGE_ASSERT(matMul != nullptr, "matMulTiling() requires a valid node.");
    AIDGE_ASSERT(matMul->getOperator()->type() == "MatMul", "Operator should be a MatMul.");
    AIDGE_ASSERT(matMul->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
    AIDGE_ASSERT(maxDims.size() >= 2, "maxDims must provide a maximum for both rows and columns.");
    // A null maximum can never be satisfied and would lead to endless recursion.
    AIDGE_ASSERT(maxDims[0] > 0 && maxDims[1] > 0, "maxDims entries must be strictly positive.");

    const auto op = std::static_pointer_cast<OperatorTensor>(matMul->getOperator());
    AIDGE_ASSERT(op->dimsForwarded(), "Dimensions must be forwarded before any tiling");

    const auto& in0Tensor = op->getInput(0);
    const auto& in1Tensor = op->getInput(1);
    const auto& outTensor = op->getOutput(0);
    const auto& input0Dims = in0Tensor->dims();
    const auto& input1Dims = in1Tensor->dims();
    const auto& outputDims = outTensor->dims();

    // The tiling works on the trailing (rows, columns) pair of each tensor;
    // without it, indexing the last two dimensions would be out of bounds.
    AIDGE_ASSERT(input0Dims.size() >= 2 && input1Dims.size() >= 2 && outputDims.size() >= 2,
        "MatMul tiling requires inputs and output with at least 2 dimensions.");

    const DimSize_t outRows = outputDims[outputDims.size() - 2];
    const DimSize_t outCols = outputDims[outputDims.size() - 1];

    // Rows are handled first; columns of the resulting halves are handled by
    // the recursive calls below.
    const bool splitRows = outRows > maxDims[0];
    const bool splitCols = outCols > maxDims[1];
    if (!splitRows && !splitCols) {
        return;  // already small enough
    }

    // Splitting rows slices input #0, splitting columns slices input #1.
    const auto& slicedDims = splitRows ? input0Dims : input1Dims;
    // Slice only the two matrix axes so that leading (batch) axes stay whole.
    // For rank-2 tensors these are axes 0 and 1, as before.
    const DimSize_t rowAxis = slicedDims.size() - 2;
    const DimSize_t colAxis = slicedDims.size() - 1;
    const DimSize_t nbRows = slicedDims[rowAxis];
    const DimSize_t nbCols = slicedDims[colAxis];
    const DimSize_t zero = 0;
    // The exceeded dimension is > maxDims[i] >= 1, hence >= 2: both halves are non-empty.
    const DimSize_t splitIndex = (splitRows ? outRows : outCols) / 2;
    const DimSize_t concatAxis = outputDims.size() - (splitRows ? 2 : 1);

    // Creates a constant 2-element Producer and plugs it on the given Slice input.
    const auto attachParam = [](const NodePtr& slice, DimSize_t first, DimSize_t second, IOIndex_t inputIdx) {
        auto producer = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{first, second}}), "", true);
        producer->addChild(slice, 0, inputIdx);
        return producer;
    };

    // One Identity per original MatMul input: they are the entry points of the micro-graph.
    auto identity0 = Identity();
    auto identity1 = Identity();

    // First half: [0, splitIndex) along the split axis, full range on the other one.
    auto slice0 = Slice();
    auto slice0_starts = attachParam(slice0, zero, zero, 1);
    auto slice0_ends = attachParam(slice0, splitRows ? splitIndex : nbRows, splitRows ? nbCols : splitIndex, 2);
    auto slice0_axes = attachParam(slice0, rowAxis, colAxis, 3);
    auto slice0_steps = attachParam(slice0, 1, 1, 4);
    auto matMul0 = MatMul();

    // Second half: [splitIndex, end) along the split axis, full range on the other one.
    auto slice1 = Slice();
    auto slice1_starts = attachParam(slice1, splitRows ? splitIndex : zero, splitRows ? zero : splitIndex, 1);
    auto slice1_ends = attachParam(slice1, nbRows, nbCols, 2);
    auto slice1_axes = attachParam(slice1, rowAxis, colAxis, 3);
    auto slice1_steps = attachParam(slice1, 1, 1, 4);
    auto matMul1 = MatMul();

    auto concat = Concat(2, concatAxis);

    if (splitRows) {
        // out = concat_rows(in0[:split] x in1, in0[split:] x in1)
        identity0->addChild(slice0, 0, 0);
        identity0->addChild(slice1, 0, 0);
        identity1->addChild(matMul0, 0, 1);
        identity1->addChild(matMul1, 0, 1);
        slice0->addChild(matMul0, 0, 0);
        slice1->addChild(matMul1, 0, 0);
    }
    else {
        // out = concat_cols(in0 x in1[:, :split], in0 x in1[:, split:])
        identity1->addChild(slice0, 0, 0);
        identity1->addChild(slice1, 0, 0);
        identity0->addChild(matMul0, 0, 0);
        identity0->addChild(matMul1, 0, 0);
        slice0->addChild(matMul0, 0, 1);
        slice1->addChild(matMul1, 0, 1);
    }
    matMul0->addChild(concat, 0, 0);
    matMul1->addChild(concat, 0, 1);

    auto gMatMul = std::make_shared<GraphView>();
    gMatMul->add({matMul});

    auto g = std::make_shared<GraphView>();
    // The identities are added first, in the same order as the MatMul inputs
    // they stand for, before the rest of the micro-graph.
    g->add({identity0, identity1});
    g->add({slice0, slice0_starts, slice0_ends, slice0_axes, slice0_steps, matMul0, matMul1, slice1, slice1_starts, slice1_ends, slice1_axes, slice1_steps, concat});
    // Note: the former debug dump g->save("micrograph") was removed; it wrote a
    // file in the working directory at every recursion step.

    auto replaced = GraphView::replace(gMatMul, g);

    if (replaced) {
        g->forwardDims({}, true);

        // Recursive tiling
        matMulTiling(matMul0, maxDims);
        matMulTiling(matMul1, maxDims);
    }
    else {
        Log::warn("Unable to split MatMul {}", matMul->name());
    }
}