diff --git a/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp index a63e9e124596fa529b53068fb6a589ebf42f5f55..9eae167de0ad02eef79510d0e18b6e43ce660d74 100644 --- a/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp @@ -15,7 +15,6 @@ #include "aidge/utils/Registrar.hpp" #include <algorithm> -// #include <omp.h> #include "aidge/backend/cpu/operator/MatMulImpl.hpp" namespace Aidge { @@ -23,55 +22,43 @@ namespace Aidge { template <class I, class O> void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,const std::vector<DimSize_t>& input2Dims, const void* input1_, const void* input2_, void* output_) { - // FIXME: missing MatMul parameters as arguments const I* input1 = static_cast<const I*>(input1_); const I* input2 = static_cast<const I*>(input2_); O* output = static_cast<O*>(output_); - size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0; + assert((input1Dims.size()>1 && input2Dims.size()>1) && "Inputs must be at least 2D for MatMul"); // Checking if matrix dimensions are compatible for multiplication - assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] && + assert(input1Dims[input1Dims.size()-1] == input2Dims[input2Dims.size() - 2] && "Matrix dimensions are not compatible for multiplication"); std::size_t innerMulAxis = input1Dims[input1Dims.size()-1]; std::size_t rows1 = input1Dims[input1Dims.size()-2]; std::size_t cols2 = input2Dims[input2Dims.size()-1]; + // Compute number of 2D matrices in input1 and input2 std::size_t nbMat1 = 1, nbMat2 = 1; - if (input1Dims.size()>2) - { - for (std::size_t i = 0; i < input1Dims.size()-2; i++) - { - nbMat1 *= input1Dims[i]; - } - + for (std::size_t i = 0; i < input1Dims.size()-2; i++) { + nbMat1 *= input1Dims[i]; } - if (input2Dims.size()>2) - { - for (std::size_t i = 0; i < input2Dims.size()-2; i++) - { - nbMat2 *= input2Dims[i]; - } - + for (std::size_t i = 0; i < input2Dims.size()-2; i++) { + nbMat2 *= input2Dims[i]; } + std::size_t mat1Size = rows1 * innerMulAxis; std::size_t mat2Size = innerMulAxis * cols2; std::size_t matSize = rows1 * cols2; - std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2; + std::size_t nbMat = std::max(nbMat1, nbMat2); for (std::size_t i = 0; i < nbMat; i++) { -// #pragma omp parallel for num_threads(8) - for (std::size_t m = 0; m < rows1; m++) - { - - for (size_t k = 0; k < innerMulAxis; k++) - { - for (std::size_t n = 0; n < cols2; n++) - { + for (std::size_t m = 0; m < rows1; m++) { + for (size_t k = 0; k < innerMulAxis; k++) { + const std::size_t input1Idx = (i % nbMat1) * mat1Size + m * innerMulAxis + k; + for (std::size_t n = 0; n < cols2; n++) { + const std::size_t outputIdx = i * matSize + m * cols2 + n; + const std::size_t input2Idx = (i % nbMat2) * mat2Size + k * cols2 + n; if (k==0) { - output[i * matSize + m * cols2 + n] = 0; + output[outputIdx] = 0; } - - output[i * matSize + m * cols2 + n] += input1[(i%nbMat1) * mat1Size + m *innerMulAxis + k] * input2[(i%nbMat2)*mat2Size + k * cols2 + n]; + output[outputIdx] += input1[input1Idx] * input2[input2Idx]; } } }