Skip to content
Snippets Groups Projects

Matmul rework

Merged Houssem ROUIS requested to merge (removed):matmul_rework into dev
3 files
+ 105
53
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -15,6 +15,7 @@
#include "aidge/utils/Registrar.hpp"
#include <algorithm>
// #include <omp.h>
#include "aidge/backend/cpu/operator/MatMulImpl.hpp"
namespace Aidge {
@@ -26,35 +27,55 @@ void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,cons
const I* input1 = static_cast<const I*>(input1_);
const I* input2 = static_cast<const I*>(input2_);
O* output = static_cast<O*>(output_);
size_t secondToLastIdx1 = input1Dims.size() > 1 ? input1Dims.size() - 2 : 0;
size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0;
// Checking if matrix dimensions are compatible for multiplication
assert(input1Dims.back() == input2Dims[secondToLastIdx2] && "Matrix dimensions are not compatible for multiplication");
// Extracting dimensions
size_t rows1 = 1, cols1 = 1, cols2 = 1;
// For input1
for (size_t i = 0; i < input1Dims.size() - 1; ++i) {
rows1 *= input1Dims[i];
}
cols1 = input1Dims.back();
assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] &&
"Matrix dimensions are not compatible for multiplication");
// For input2
for (size_t i = 1; i < input2Dims.size(); ++i) {
cols2 *= input2Dims[i];
}
std::size_t innerMulAxis = input1Dims[input1Dims.size()-1];
std::size_t rows1 = input1Dims[input1Dims.size()-2];
std::size_t cols2 = input2Dims[input2Dims.size()-1];
std::size_t nbMat1 = 1, nbMat2 = 1;
if (input1Dims.size()>2)
{
for (std::size_t i = 0; i < input1Dims.size()-2; i++)
{
nbMat1 *= input1Dims[i];
}
}
if (input2Dims.size()>2)
{
for (std::size_t i = 0; i < input2Dims.size()-2; i++)
{
nbMat2 *= input2Dims[i];
}
}
std::size_t mat1Size = rows1 * innerMulAxis;
std::size_t mat2Size = innerMulAxis * cols2;
std::size_t matSize = rows1 * cols2;
std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2;
// Multiplication
for (size_t i = 0; i < rows1; ++i) {
for (size_t j = 0; j < cols2; ++j) {
float sum = 0.0;
for (size_t k = 0; k < cols1; ++k) {
sum += input1[i * cols1 + k] * input2[k * cols2 + j];
}
output[i * cols2 + j] = sum;
}
}
for (std::size_t i = 0; i < nbMat; i++) {
// #pragma omp parallel for num_threads(8)
for (std::size_t m = 0; m < rows1; m++)
{
for (size_t k = 0; k < innerMulAxis; k++)
{
for (std::size_t n = 0; n < cols2; n++)
{
if (k==0) {
output[i * matSize + m * cols2 + n] = 0;
}
output[i * matSize + m * cols2 + n] += input1[(i%nbMat1) * mat1Size + m *innerMulAxis + k] * input2[(i%nbMat2)*mat2Size + k * cols2 + n];
}
}
}
}
}
namespace {
Loading