Skip to content
Snippets Groups Projects
Commit 5f570613 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

cleanup MatMul forward kernel

parent 204f80b0
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!34Matmul rework
......@@ -15,7 +15,6 @@
#include "aidge/utils/Registrar.hpp"
#include <algorithm>
// #include <omp.h>
#include "aidge/backend/cpu/operator/MatMulImpl.hpp"
namespace Aidge {
......@@ -23,55 +22,43 @@ namespace Aidge {
template <class I, class O>
void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,const std::vector<DimSize_t>& input2Dims,
const void* input1_, const void* input2_, void* output_) {
// FIXME: missing MatMul parameters as arguments
const I* input1 = static_cast<const I*>(input1_);
const I* input2 = static_cast<const I*>(input2_);
O* output = static_cast<O*>(output_);
size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0;
assert((input1Dims.size()>1 && input2Dims.size()>1) && "Inputs must be at least 2D for MatMul");
// Checking if matrix dimensions are compatible for multiplication
assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] &&
assert(input1Dims[input1Dims.size()-1] == input2Dims[input2Dims.size() - 2] &&
"Matrix dimensions are not compatible for multiplication");
std::size_t innerMulAxis = input1Dims[input1Dims.size()-1];
std::size_t rows1 = input1Dims[input1Dims.size()-2];
std::size_t cols2 = input2Dims[input2Dims.size()-1];
// Compute number of 2D matrices in input1 and input2
std::size_t nbMat1 = 1, nbMat2 = 1;
if (input1Dims.size()>2)
{
for (std::size_t i = 0; i < input1Dims.size()-2; i++)
{
nbMat1 *= input1Dims[i];
}
for (std::size_t i = 0; i < input1Dims.size()-2; i++) {
nbMat1 *= input1Dims[i];
}
if (input2Dims.size()>2)
{
for (std::size_t i = 0; i < input2Dims.size()-2; i++)
{
nbMat2 *= input2Dims[i];
}
for (std::size_t i = 0; i < input2Dims.size()-2; i++) {
nbMat2 *= input2Dims[i];
}
std::size_t mat1Size = rows1 * innerMulAxis;
std::size_t mat2Size = innerMulAxis * cols2;
std::size_t matSize = rows1 * cols2;
std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2;
std::size_t nbMat = std::max(nbMat1, nbMat2);
for (std::size_t i = 0; i < nbMat; i++) {
// #pragma omp parallel for num_threads(8)
for (std::size_t m = 0; m < rows1; m++)
{
for (size_t k = 0; k < innerMulAxis; k++)
{
for (std::size_t n = 0; n < cols2; n++)
{
for (std::size_t m = 0; m < rows1; m++) {
for (size_t k = 0; k < innerMulAxis; k++) {
const std::size_t input1Idx = (i % nbMat1) * mat1Size + m * innerMulAxis + k;
for (std::size_t n = 0; n < cols2; n++) {
const std::size_t outputIdx = i * matSize + m * cols2 + n;
const std::size_t input2Idx = (i % nbMat2) * mat2Size + k * cols2 + n;
if (k==0) {
output[i * matSize + m * cols2 + n] = 0;
output[outputIdx] = 0;
}
output[i * matSize + m * cols2 + n] += input1[(i%nbMat1) * mat1Size + m *innerMulAxis + k] * input2[(i%nbMat2)*mat2Size + k * cols2 + n];
output[outputIdx] += input1[input1Idx] * input2[input2Idx];
}
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment