Skip to content
Snippets Groups Projects
Commit 5f570613 authored by Houssem ROUIS
Browse files

cleanup MatMul forward kernel

parent 204f80b0
No related branches found
No related tags found
2 merge requests: !50 (version 0.2.0), !34 (Matmul rework)
This commit is part of merge request !34. Comments created here will be created in the context of that merge request.
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include <algorithm> #include <algorithm>
// #include <omp.h>
#include "aidge/backend/cpu/operator/MatMulImpl.hpp" #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
namespace Aidge { namespace Aidge {
...@@ -23,55 +22,43 @@ namespace Aidge { ...@@ -23,55 +22,43 @@ namespace Aidge {
template <class I, class O> template <class I, class O>
void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,const std::vector<DimSize_t>& input2Dims, void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,const std::vector<DimSize_t>& input2Dims,
const void* input1_, const void* input2_, void* output_) { const void* input1_, const void* input2_, void* output_) {
// FIXME: missing MatMul parameters as arguments
const I* input1 = static_cast<const I*>(input1_); const I* input1 = static_cast<const I*>(input1_);
const I* input2 = static_cast<const I*>(input2_); const I* input2 = static_cast<const I*>(input2_);
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0; assert((input1Dims.size()>1 && input2Dims.size()>1) && "Inputs must be at least 2D for MatMul");
// Checking if matrix dimensions are compatible for multiplication // Checking if matrix dimensions are compatible for multiplication
assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] && assert(input1Dims[input1Dims.size()-1] == input2Dims[input2Dims.size() - 2] &&
"Matrix dimensions are not compatible for multiplication"); "Matrix dimensions are not compatible for multiplication");
std::size_t innerMulAxis = input1Dims[input1Dims.size()-1]; std::size_t innerMulAxis = input1Dims[input1Dims.size()-1];
std::size_t rows1 = input1Dims[input1Dims.size()-2]; std::size_t rows1 = input1Dims[input1Dims.size()-2];
std::size_t cols2 = input2Dims[input2Dims.size()-1]; std::size_t cols2 = input2Dims[input2Dims.size()-1];
// Compute number of 2D matrices in input1 and input2
std::size_t nbMat1 = 1, nbMat2 = 1; std::size_t nbMat1 = 1, nbMat2 = 1;
if (input1Dims.size()>2) for (std::size_t i = 0; i < input1Dims.size()-2; i++) {
{ nbMat1 *= input1Dims[i];
for (std::size_t i = 0; i < input1Dims.size()-2; i++)
{
nbMat1 *= input1Dims[i];
}
} }
if (input2Dims.size()>2) for (std::size_t i = 0; i < input2Dims.size()-2; i++) {
{ nbMat2 *= input2Dims[i];
for (std::size_t i = 0; i < input2Dims.size()-2; i++)
{
nbMat2 *= input2Dims[i];
}
} }
std::size_t mat1Size = rows1 * innerMulAxis; std::size_t mat1Size = rows1 * innerMulAxis;
std::size_t mat2Size = innerMulAxis * cols2; std::size_t mat2Size = innerMulAxis * cols2;
std::size_t matSize = rows1 * cols2; std::size_t matSize = rows1 * cols2;
std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2; std::size_t nbMat = std::max(nbMat1, nbMat2);
for (std::size_t i = 0; i < nbMat; i++) { for (std::size_t i = 0; i < nbMat; i++) {
// #pragma omp parallel for num_threads(8) for (std::size_t m = 0; m < rows1; m++) {
for (std::size_t m = 0; m < rows1; m++) for (size_t k = 0; k < innerMulAxis; k++) {
{ const std::size_t input1Idx = (i % nbMat1) * mat1Size + m * innerMulAxis + k;
for (std::size_t n = 0; n < cols2; n++) {
for (size_t k = 0; k < innerMulAxis; k++) const std::size_t outputIdx = i * matSize + m * cols2 + n;
{ const std::size_t input2Idx = (i % nbMat2) * mat2Size + k * cols2 + n;
for (std::size_t n = 0; n < cols2; n++)
{
if (k==0) { if (k==0) {
output[i * matSize + m * cols2 + n] = 0; output[outputIdx] = 0;
} }
output[outputIdx] += input1[input1Idx] * input2[input2Idx];
output[i * matSize + m * cols2 + n] += input1[(i%nbMat1) * mat1Size + m *innerMulAxis + k] * input2[(i%nbMat2)*mat2Size + k * cols2 + n];
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment