// MatMulTiling.cpp
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/
#include <cassert>
#include <memory>
#include <set>
#include <string>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"

/**
 * @brief Recursively split a MatMul node whose output matrix is too large.
 *
 * See https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
 *
 * The size of the output matrix is read from the last two dimensions of the
 * operator output. When its number of rows exceeds maxDims[0], the node is
 * replaced (through GraphView::replace) by a micro-graph that:
 *  - slices input #0 in two parts at row `rows / 2`,
 *  - multiplies each part by input #1 with a new MatMul node,
 *  - concatenates both partial results along axis 0.
 * The two new MatMul nodes are then tiled again with the same limits.
 *
 * Splitting along the columns (output columns > maxDims[1]) is not implemented:
 * a warning is logged and the node is left untouched.
 *
 * @param matMul  Node holding the MatMul operator to tile. Its dimensions must
 *                have been forwarded beforehand.
 * @param maxDims Maximum size of the output matrix as {rows, columns}; must
 *                hold at least two values.
 *
 * All preconditions are checked with AIDGE_ASSERT.
 */
void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) {
    // The condition itself must be handed to the assertion macro: asserting on
    // a string literal always succeeds and would let any operator through.
    AIDGE_ASSERT(matMul->getOperator()->type() == "MatMul", "Operator should be a MatMul.");
    AIDGE_ASSERT(matMul->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
    const auto op = std::static_pointer_cast<OperatorTensor>(matMul->getOperator());
    AIDGE_ASSERT(op->dimsForwarded(), "Dimensions must be forwarded before any tiling");
    AIDGE_ASSERT(maxDims.size() >= 2, "maxDims must provide a maximum for both rows and columns.");

    const auto& in0Tensor = op->getInput(0);
    const auto& outTensor = op->getOutput(0);
    const auto& input0Dims = in0Tensor->dims();
    const auto& outputDims = outTensor->dims();
    // Guard the "last two dimensions" accesses below.
    AIDGE_ASSERT(outputDims.size() >= 2, "MatMul output must have at least 2 dimensions to be tiled.");

    const std::size_t outputRows = outputDims[outputDims.size() - 2];
    const std::size_t outputCols = outputDims[outputDims.size() - 1];

    if (outputRows > maxDims[0]) {
        // With a null limit a single-row matrix would be split into 0 + 1 rows
        // forever, so reject it before recursing.
        AIDGE_ASSERT(maxDims[0] > 0, "Maximum number of rows must be strictly positive.");
        AIDGE_ASSERT(input0Dims.size() >= 2, "MatMul input #0 must have at least 2 dimensions to be split.");

        // NOTE(review): the slices and the concatenation below address the
        // first two dimensions (axes 0 and 1) of input #0 whereas the size
        // check uses the last two dimensions of the output. Both agree for 2-D
        // operands; behaviour with batched operands should be confirmed.
        const std::size_t axis = 0;
        const auto splitIndex = outputRows / 2;

        // Create a constant producer holding {first, second} and plug it on
        // input #inputIdx of the given Slice node.
        const auto attachConstant = [](const NodePtr& slice, IOIndex_t inputIdx, DimSize_t first, DimSize_t second) {
            auto producer = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{first, second}}), "", true);
            producer->addChild(slice, 0, inputIdx);
            return producer;
        };

        // First half: rows [0, splitIndex) of input #0.
        auto identity0 = Identity();
        auto slice00 = Slice();
        auto slice00_starts = attachConstant(slice00, 1, 0, 0);
        auto slice00_ends = attachConstant(slice00, 2, splitIndex, input0Dims[1]);
        auto slice00_axes = attachConstant(slice00, 3, 0, 1);
        auto slice00_steps = attachConstant(slice00, 4, 1, 1);
        auto matMul00 = MatMul();

        // Second half: rows [splitIndex, input0Dims[0]) of input #0.
        auto identity1 = Identity();
        auto slice01 = Slice();
        auto slice01_starts = attachConstant(slice01, 1, splitIndex, 0);
        auto slice01_ends = attachConstant(slice01, 2, input0Dims[0], input0Dims[1]);
        auto slice01_axes = attachConstant(slice01, 3, 0, 1);
        auto slice01_steps = attachConstant(slice01, 4, 1, 1);
        auto matMul01 = MatMul();
        auto concat0 = Concat(2, axis);

        // identity0 / identity1 stand for input #0 / input #1 of the replaced
        // MatMul: input #0 feeds both slices, input #1 feeds both new MatMul.
        identity0->addChild(slice00, 0, 0);
        identity0->addChild(slice01, 0, 0);
        identity1->addChild(matMul00, 0, 1);
        identity1->addChild(matMul01, 0, 1);
        slice00->addChild(matMul00, 0, 0);
        slice01->addChild(matMul01, 0, 0);
        matMul00->addChild(concat0, 0, 0);
        matMul01->addChild(concat0, 0, 1);

        auto gMatMul = std::make_shared<GraphView>();
        gMatMul->add({matMul});

        // The identities are added first, in the order of the inputs they
        // stand for.
        auto g = std::make_shared<GraphView>();
        g->add({identity0, identity1});
        g->add({slice00, slice00_starts, slice00_ends, slice00_axes, slice00_steps, matMul00, matMul01, slice01, slice01_starts, slice01_ends, slice01_axes, slice01_steps, concat0});
        // NOTE(review): looks like a debugging leftover; it is executed for
        // every split, always with the same name. Confirm before removing.
        g->save("micrograph");

        auto replaced = GraphView::replace(gMatMul, g);

        if (replaced) {
            // The new MatMul nodes need their dimensions before being tiled.
            g->forwardDims({}, true);

            // Recursive tiling
            matMulTiling(matMul00, maxDims);
            matMulTiling(matMul01, maxDims);
        }
        else {
            Log::warn("Unable to split MatMul {}", matMul->name());
        }
    }
    else if (outputCols > maxDims[1]) {
        // Make the missing feature visible instead of silently doing nothing.
        Log::warn("Column tiling is not implemented: MatMul {} is left untouched", matMul->name());
    }
}