Skip to content
Snippets Groups Projects
Commit 204f80b0 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix matmul to support more matrix shapes

parent 109b5a64
No related branches found
No related tags found
2 merge requests: !50 "version 0.2.0", !34 "MatMul rework"
Pipeline #38568 failed
......@@ -15,6 +15,7 @@
#include "aidge/utils/Registrar.hpp"
#include <algorithm>
// #include <omp.h>
#include "aidge/backend/cpu/operator/MatMulImpl.hpp"
namespace Aidge {
......@@ -26,35 +27,55 @@ void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,cons
const I* input1 = static_cast<const I*>(input1_);
const I* input2 = static_cast<const I*>(input2_);
O* output = static_cast<O*>(output_);
size_t secondToLastIdx1 = input1Dims.size() > 1 ? input1Dims.size() - 2 : 0;
size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0;
// Checking if matrix dimensions are compatible for multiplication
assert(input1Dims.back() == input2Dims[secondToLastIdx2] && "Matrix dimensions are not compatible for multiplication");
// Extracting dimensions
size_t rows1 = 1, cols1 = 1, cols2 = 1;
// For input1
for (size_t i = 0; i < input1Dims.size() - 1; ++i) {
rows1 *= input1Dims[i];
}
cols1 = input1Dims.back();
assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] &&
"Matrix dimensions are not compatible for multiplication");
// For input2
for (size_t i = 1; i < input2Dims.size(); ++i) {
cols2 *= input2Dims[i];
}
std::size_t innerMulAxis = input1Dims[input1Dims.size()-1];
std::size_t rows1 = input1Dims[input1Dims.size()-2];
std::size_t cols2 = input2Dims[input2Dims.size()-1];
std::size_t nbMat1 = 1, nbMat2 = 1;
if (input1Dims.size()>2)
{
for (std::size_t i = 0; i < input1Dims.size()-2; i++)
{
nbMat1 *= input1Dims[i];
}
}
if (input2Dims.size()>2)
{
for (std::size_t i = 0; i < input2Dims.size()-2; i++)
{
nbMat2 *= input2Dims[i];
}
}
std::size_t mat1Size = rows1 * innerMulAxis;
std::size_t mat2Size = innerMulAxis * cols2;
std::size_t matSize = rows1 * cols2;
std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2;
// Multiplication
for (size_t i = 0; i < rows1; ++i) {
for (size_t j = 0; j < cols2; ++j) {
float sum = 0.0;
for (size_t k = 0; k < cols1; ++k) {
sum += input1[i * cols1 + k] * input2[k * cols2 + j];
}
output[i * cols2 + j] = sum;
}
}
for (std::size_t i = 0; i < nbMat; i++) {
// #pragma omp parallel for num_threads(8)
for (std::size_t m = 0; m < rows1; m++)
{
for (size_t k = 0; k < innerMulAxis; k++)
{
for (std::size_t n = 0; n < cols2; n++)
{
if (k==0) {
output[i * matSize + m * cols2 + n] = 0;
}
output[i * matSize + m * cols2 + n] += input1[(i%nbMat1) * mat1Size + m *innerMulAxis + k] * input2[(i%nbMat2)*mat2Size + k * cols2 + n];
}
}
}
}
}
namespace {
......
......@@ -32,13 +32,10 @@ void Aidge::MatMulImpl_cpu::forward()
{std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
kernelFunc(
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims(),
getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawInput(1)),
getCPUPtr(mOp.getRawOutput(0)));
}
......@@ -59,24 +59,28 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
}
SECTION("3D Tensor by 1D Tensor") {
std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,2,2,3> {
SECTION("3D Tensor by 2D Tensor") {
std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,1,2,3> {
{
{{0.82786506, 0.19047028, 0.62954658},
{0.63160968, 0.12468684, 0.49015969}},
{{0.49215794, 0.42231840, 0.02699018},
{0.66403216, 0.94622904, 0.42048711}}
{
{0.53427607, 0.69181818, 0.30088913},
{0.20866227, 0.67821276, 0.25695610}
}
}
});
std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<float,3>{
{0.82458717, 0.88598752, 0.78737932}
std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array2D<float,3,4>{
{
{0.03158629, 0.21031839, 0.95692378, 0.05287921},
{0.66182911, 0.91662365, 0.07928377, 0.86983263},
{0.12386280, 0.63736272, 0.15963674, 0.465079722}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,2> {
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<float,1,2,4> {
{
{1.34709311, 1.01722980},
{0.80124742, 1.71698236}
{
{0.51201022, 0.93828046, 0.61414438, 0.76995558},
{0.48727912, 0.82932562, 0.29446477, 0.72047055}
}
}
});
......@@ -99,27 +103,57 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
}
SECTION("3D Tensor by 2D Tensor") {
std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,1,2,3> {
SECTION("4D Tensors") {
std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<float,1,2,4,3> {
{
{
{0.53427607, 0.69181818, 0.30088913},
{0.20866227, 0.67821276, 0.25695610}
}
{
{0.78191108, 0.79929698, 0.45473319},
{0.35713595, 0.45651042, 0.40217435},
{0.15343380, 0.30024308, 0.78940034},
{0.53266525, 0.16684306, 0.22095734}
},
{
{0.89860427, 0.75139457, 0.34270161},
{0.53609246, 0.62800729, 0.68399906},
{0.57119054, 0.96259099, 0.71879345},
{0.73910689, 0.62526798, 0.77325356}
}
}
}
});
std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array2D<float,3,4>{
std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array4D<float,1,2,3,4>{
{
{0.03158629, 0.21031839, 0.95692378, 0.05287921},
{0.66182911, 0.91662365, 0.07928377, 0.86983263},
{0.12386280, 0.63736272, 0.15963674, 0.465079722}
{
{
{0.36525106, 0.47606337, 0.58315367, 0.33944082},
{0.56211257, 0.64100796, 0.28841895, 0.11285251},
{0.04657018, 0.21112120, 0.88220179, 0.23004770}
},
{
{0.33073467, 0.45434207, 0.92689610, 0.02250439},
{0.57044137, 0.88543379, 0.23575044, 0.57311541},
{0.21721125, 0.16826588, 0.45728493, 0.81760287}
}
}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<float,1,2,4> {
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,2,4,4> {
{
{
{0.51201022, 0.93828046, 0.61414438, 0.76995558},
{0.48727912, 0.82932562, 0.29446477, 0.72047055}
{
{0.75606567, 0.98059881, 1.08767319, 0.46022552},
{0.40578386, 0.54755372, 0.69473034, 0.26526415},
{0.26157477, 0.43216154, 0.87248170, 0.26756462},
{0.29863116, 0.40717891, 0.55367535, 0.25046772}
},
{
{0.80026478, 1.13124883, 1.16676664, 0.73105216},
{0.68411803, 0.91472197, 0.95773751, 0.93122470},
{0.89414424, 1.23277485, 1.08505893, 1.15221763},
{0.76908636, 1.01955295, 1.18607962, 1.00719821}
}
}
}
});
......
0% Loading, or try again.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment