/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cstddef>  // std::size_t
#include <cstdint>  // std::int32_t
#include <numeric>  // std::accumulate
#include <vector>

#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/operator/MatMul.hpp"
#include "aidge/utils/Types.h"

#include "aidge/backend/cpu/operator/MatMulImpl.hpp"
#include "aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp"

void Aidge::MatMulImpl_cpu::forward()
{
    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");

    // Find the correct kernel type
    auto kernelFunc = Registrar<MatMulImplForward_cpu>::create(
        {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});

    // Compute broadcast-compatible input dimensions
    std::vector<std::size_t> dims0 = static_cast<const MatMul_Op&>(mOp).getInput(0)->dims();
    std::vector<std::size_t> dims1 = static_cast<const MatMul_Op&>(mOp).getInput(1)->dims();

    // the second-to-last dimension of input 0 is kept in the output only if input 0 is at least 2-D
    const std::size_t keepDim0 = (dims0.size() > 1) ? 1 : 0;
    // the last dimension of input 1 is kept in the output only if input 1 is at least 2-D
    const std::size_t keepDim1 = (dims1.size() > 1) ? 1 : 0;

    // promote 1-D inputs to 2-D: input 0 becomes a row vector, input 1 a column vector
    if (dims0.size() == 1) {
        dims0.insert(dims0.cbegin(), 1);
    }
    if (dims1.size() == 1) {
        dims1.push_back(1);
    }

    // left-pad the lower-rank input with dimensions of size 1 so both inputs have the same rank
    if (dims0.size() > dims1.size()) {
        dims1.insert(dims1.cbegin(), dims0.size() - dims1.size(), std::size_t(1));
    }
    else if (dims1.size() > dims0.size()) {
        dims0.insert(dims0.cbegin(), dims1.size() - dims0.size(), std::size_t(1));
    }
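
    // Illustration (hypothetical shapes, not taken from an actual model): with
    // input 0 of shape {k} and input 1 of shape {b, k, m}, dims0 first becomes
    // {1, k} and is then left-padded to {1, 1, k}; since keepDim0 == 0, the
    // output is expected to have shape {b, m}.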

    // const std::size_t dims_size = std::max(dims0.size(), dims1.size());
    // at this point, dims0.size() == dims1.size()
    const std::size_t nbDims = dims0.size();

    // strides used to walk through the batch dimensions of each input, accounting for broadcasting
    std::unique_ptr<std::size_t[]> stride_post0 = std::make_unique<std::size_t[]>(nbDims - 2);
    std::unique_ptr<std::size_t[]> stride_post1 = std::make_unique<std::size_t[]>(nbDims - 2);
    std::unique_ptr<std::int32_t[]> stride_step0 = std::make_unique<std::int32_t[]>(nbDims - 2);
    std::unique_ptr<std::int32_t[]> stride_step1 = std::make_unique<std::int32_t[]>(nbDims - 2);
    if (nbDims > 2) {
        stride_post0[nbDims - 3] = 1;
        stride_post1[nbDims - 3] = 1;
        // note: when nbDims == 3, nbDims - 4 wraps around to SIZE_MAX, so this loop is skipped
        for (std::size_t i = nbDims-4; i != static_cast<std::size_t>(-1); --i) {
            stride_post0[i] = stride_post0[i+1]*dims0[i+1];
            stride_post1[i] = stride_post1[i+1]*dims1[i+1];
        }
        for (std::size_t i = 0; i != nbDims-2; ++i) {
            stride_step0[i] = (dims0[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post0[i]) : 1;
            stride_step1[i] = (dims1[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post1[i]) : 1;
        }
    }
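
    // Worked example (hypothetical shapes): with dims0 = {2, 1, n, k} and
    // dims1 = {1, 3, k, m}, this gives stride_post0 = {1, 1}, stride_post1 = {3, 1},
    // stride_step0 = {1, 0} and stride_step1 = {-2, 1}. Walking over the 2*3 = 6
    // output matrices below then produces input matrix offsets 0,0,0,1,1,1 for
    // input 0 and 0,1,2,0,1,2 for input 1, i.e. each input is re-read along its
    // broadcast dimension.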

    const std::vector<std::size_t>& outDims = static_cast<const MatMul_Op&>(mOp).getOutput(0)->dims();
    const std::size_t nbMatrices = std::accumulate(outDims.cbegin(), outDims.cend() - keepDim0 - keepDim1, std::size_t(1), std::multiplies<std::size_t>());
    std::size_t dim = outDims.size() - 1 - keepDim0 - keepDim1;

    // offsets (in number of matrices) into the input and output arrays
    std::size_t offsetIn0 = 0;
    std::size_t offsetIn1 = 0;
    std::size_t offsetOut = 0;
    const std::size_t n = dims0[nbDims - 2];
    const std::size_t k = dims0[nbDims - 1];
    const std::size_t m = dims1[nbDims - 1];
    const std::size_t matrix0Size = n*k;
    const std::size_t matrix1Size = k*m;
    const std::size_t matrixOutSize = n*m;
    for (std::size_t stack = 0; stack < nbMatrices;) {
        kernelFunc(n, k, m,
                    getCPUPtr(mOp.getRawInput(0), offsetIn0*matrix0Size),
                    getCPUPtr(mOp.getRawInput(1), offsetIn1*matrix1Size),
                    getCPUPtr(mOp.getRawOutput(0), offsetOut*matrixOutSize));
        if (++stack < nbMatrices) {
            // find the batch dimension whose index increments at this step
            // (inner dimensions that wrap around are skipped over)
            std::size_t tmp_stack = stack;
            while (tmp_stack % outDims[dim] == 0) {
                tmp_stack /= outDims[dim];
                dim--;
            }
            // advance each input by its step for that dimension; for a broadcast
            // dimension the step rewinds the offset instead of advancing it
            offsetIn0 += stride_step0[dim];
            offsetIn1 += stride_step1[dim];
            ++offsetOut;
            dim = outDims.size() - 1 - keepDim0 - keepDim1;
        }
    }
}
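
// For reference, a minimal sketch of the per-matrix product that each kernel call
// above computes (illustrative only: the real kernels live in
// MatMulImpl_forward_kernels.hpp and are selected through the Registrar; the float
// element type and the function name are assumptions made for this sketch):
//
// void naiveMatMulKernel(std::size_t n, std::size_t k, std::size_t m,
//                        const float* in0, const float* in1, float* out) {
//     // out is an n x m matrix: out[i*m + j] = sum over l of in0[i*k + l] * in1[l*m + j]
//     for (std::size_t i = 0; i < n; ++i) {
//         for (std::size_t j = 0; j < m; ++j) {
//             float sum = 0.0f;
//             for (std::size_t l = 0; l < k; ++l) {
//                 sum += in0[i*k + l] * in1[l*m + j];
//             }
//             out[i*m + j] = sum;
//         }
//     }
// }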

// void Aidge::MatMulImpl_cpu::forward()
// {
//     assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
//     assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");

//     // Find the correct kernel type
//     auto kernelFunc = Registrar<MatMulImplForward_cpu>::create(
//         {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
//          std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});

//     kernelFunc(
//         std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
//         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims(),
//         getCPUPtr(mOp.getRawInput(0)),
//         getCPUPtr(mOp.getRawInput(1)),
//         getCPUPtr(mOp.getRawOutput(0)));
// }