Commit f70ba8c1 authored by Houssem ROUIS, committed by Maxence Naud

fix matmul to support more matrix shapes

parent fc36e10e
2 merge requests: !50 "version 0.2.0", !34 "Matmul rework"
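
In short: the old kernel flattened the first input's leading dimensions into a single row count and ran one plain 2-D multiplication, so only a 2-D (or trivially flattenable) second input was handled. The reworked kernel treats both inputs as stacks of 2-D matrices, derives each stack size from the leading (batch) dimensions, and broadcasts the smaller stack over the larger one, which makes shapes such as (1,2,3) x (3,4) and (1,2,4,3) x (1,2,3,4) work. A standalone sketch of the scheme follows the kernel diff below.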
@@ -15,6 +15,7 @@
 #include "aidge/utils/Registrar.hpp"
 #include <algorithm>
+// #include <omp.h>
 #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
namespace Aidge {
@@ -26,35 +27,55 @@ void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims, const std::vector<DimSize_t>& input2Dims,
     const I* input1 = static_cast<const I*>(input1_);
     const I* input2 = static_cast<const I*>(input2_);
     O* output = static_cast<O*>(output_);

-    size_t secondToLastIdx1 = input1Dims.size() > 1 ? input1Dims.size() - 2 : 0;
     size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0;

     // Checking if matrix dimensions are compatible for multiplication
-    assert(input1Dims.back() == input2Dims[secondToLastIdx2] && "Matrix dimensions are not compatible for multiplication");
-
-    // Extracting dimensions
-    size_t rows1 = 1, cols1 = 1, cols2 = 1;
-    // For input1
-    for (size_t i = 0; i < input1Dims.size() - 1; ++i) {
-        rows1 *= input1Dims[i];
-    }
-    cols1 = input1Dims.back();
-    // For input2
-    for (size_t i = 1; i < input2Dims.size(); ++i) {
-        cols2 *= input2Dims[i];
-    }
+    assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] &&
+           "Matrix dimensions are not compatible for multiplication");
+
+    // Dimensions of a single matrix in each input
+    std::size_t innerMulAxis = input1Dims[input1Dims.size()-1];
+    std::size_t rows1 = input1Dims[input1Dims.size()-2];
+    std::size_t cols2 = input2Dims[input2Dims.size()-1];
+
+    // Number of stacked matrices = product of the leading (batch) dimensions
+    std::size_t nbMat1 = 1, nbMat2 = 1;
+    if (input1Dims.size() > 2)
+    {
+        for (std::size_t i = 0; i < input1Dims.size()-2; i++)
+        {
+            nbMat1 *= input1Dims[i];
+        }
+    }
+    if (input2Dims.size() > 2)
+    {
+        for (std::size_t i = 0; i < input2Dims.size()-2; i++)
+        {
+            nbMat2 *= input2Dims[i];
+        }
+    }
+    std::size_t mat1Size = rows1 * innerMulAxis;
+    std::size_t mat2Size = innerMulAxis * cols2;
+    std::size_t matSize = rows1 * cols2;
+    // The smaller stack is broadcast against the larger one
+    std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2;

     // Multiplication
-    for (size_t i = 0; i < rows1; ++i) {
-        for (size_t j = 0; j < cols2; ++j) {
-            float sum = 0.0;
-            for (size_t k = 0; k < cols1; ++k) {
-                sum += input1[i * cols1 + k] * input2[k * cols2 + j];
-            }
-            output[i * cols2 + j] = sum;
-        }
-    }
+    for (std::size_t i = 0; i < nbMat; i++) {
+        // #pragma omp parallel for num_threads(8)
+        for (std::size_t m = 0; m < rows1; m++)
+        {
+            for (size_t k = 0; k < innerMulAxis; k++)
+            {
+                for (std::size_t n = 0; n < cols2; n++)
+                {
+                    // First pass over the shared axis: reset the accumulator
+                    if (k == 0) {
+                        output[i * matSize + m * cols2 + n] = 0;
+                    }
+                    output[i * matSize + m * cols2 + n] += input1[(i % nbMat1) * mat1Size + m * innerMulAxis + k]
+                                                         * input2[(i % nbMat2) * mat2Size + k * cols2 + n];
+                }
+            }
+        }
+    }
 }
namespace {
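For readers who want to try the scheme outside Aidge, here is a minimal, self-contained sketch of the batched-multiplication logic above. It is illustrative only: batchedMatMul, its signature, and the data in main are ours, not part of the Aidge API; the indexing mirrors the kernel, with the modulo on the matrix index providing the broadcast.

// Minimal standalone sketch of the batched multiplication scheme used by the
// reworked kernel (illustrative only; batchedMatMul is not an Aidge function).
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> batchedMatMul(const std::vector<std::size_t>& dims1,
                                 const std::vector<std::size_t>& dims2,
                                 const std::vector<float>& in1,
                                 const std::vector<float>& in2) {
    assert(dims1.size() >= 2 && dims2.size() >= 2);
    assert(dims1.back() == dims2[dims2.size() - 2] &&
           "Matrix dimensions are not compatible for multiplication");

    const std::size_t innerMulAxis = dims1[dims1.size() - 1];
    const std::size_t rows1 = dims1[dims1.size() - 2];
    const std::size_t cols2 = dims2[dims2.size() - 1];

    // Stack sizes: product of the leading (batch) dimensions of each input
    std::size_t nbMat1 = 1, nbMat2 = 1;
    for (std::size_t i = 0; i + 2 < dims1.size(); ++i) { nbMat1 *= dims1[i]; }
    for (std::size_t i = 0; i + 2 < dims2.size(); ++i) { nbMat2 *= dims2[i]; }

    const std::size_t mat1Size = rows1 * innerMulAxis;
    const std::size_t mat2Size = innerMulAxis * cols2;
    const std::size_t matSize  = rows1 * cols2;
    const std::size_t nbMat    = (nbMat1 > nbMat2) ? nbMat1 : nbMat2;

    std::vector<float> out(nbMat * matSize, 0.0f);
    for (std::size_t i = 0; i < nbMat; ++i) {
        for (std::size_t m = 0; m < rows1; ++m) {
            for (std::size_t k = 0; k < innerMulAxis; ++k) {
                for (std::size_t n = 0; n < cols2; ++n) {
                    // The modulo broadcasts the smaller stack over the larger one
                    out[i * matSize + m * cols2 + n] +=
                        in1[(i % nbMat1) * mat1Size + m * innerMulAxis + k] *
                        in2[(i % nbMat2) * mat2Size + k * cols2 + n];
                }
            }
        }
    }
    return out;
}

int main() {
    // A (2,2,3) stack of two 2x3 matrices times a single (3,2) matrix:
    // the 2-D input is broadcast across both matrices of the 3-D input.
    const std::vector<float> a = {1, 0, 0,  0, 1, 0,    // matrix 0
                                  0, 0, 1,  1, 1, 1};   // matrix 1
    const std::vector<float> b = {1, 2,  3, 4,  5, 6};  // 3x2
    const auto c = batchedMatMul({2, 2, 3}, {3, 2}, a, b);
    for (float v : c) { std::cout << v << ' '; }        // 1 2 3 4 5 6 9 12
    std::cout << '\n';
    return 0;
}

Compiled with g++ -std=c++14, it prints 1 2 3 4 5 6 9 12: the single 3x2 matrix is reused for both 2x3 matrices in the stack, which is the behaviour the new 3D-by-2D and 4D tests below exercise.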
@@ -32,13 +32,10 @@ void Aidge::MatMulImpl_cpu::forward()
         {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
          std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});

     kernelFunc(
+        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
+        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
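
Note the matching dispatch change: forward() now passes the dims() of both inputs to the kernel, which needs the full shapes to derive rows1, innerMulAxis, cols2 and the stack counts nbMat1/nbMat2 instead of receiving pre-flattened sizes.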
@@ -59,24 +59,28 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
     }
-    SECTION("3D Tensor by 1D Tensor") {
-        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,2,2,3> {
+    SECTION("3D Tensor by 2D Tensor") {
+        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,1,2,3> {
             {
-                {{0.82786506, 0.19047028, 0.62954658},
-                 {0.63160968, 0.12468684, 0.49015969}},
-                {{0.49215794, 0.42231840, 0.02699018},
-                 {0.66403216, 0.94622904, 0.42048711}}
+                {
+                    {0.53427607, 0.69181818, 0.30088913},
+                    {0.20866227, 0.67821276, 0.25695610}
+                }
             }
         });
-        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<float,3>{
-            {0.82458717, 0.88598752, 0.78737932}
+        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array2D<float,3,4>{
+            {
+                {0.03158629, 0.21031839, 0.95692378, 0.05287921},
+                {0.66182911, 0.91662365, 0.07928377, 0.86983263},
+                {0.12386280, 0.63736272, 0.15963674, 0.465079722}
+            }
         });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,2> {
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<float,1,2,4> {
             {
-                {1.34709311, 1.01722980},
-                {0.80124742, 1.71698236}
+                {
+                    {0.51201022, 0.93828046, 0.61414438, 0.76995558},
+                    {0.48727912, 0.82932562, 0.29446477, 0.72047055}
+                }
             }
         });
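
Spot check for this new 3D-by-2D section: output[0][0][0] is the dot product of the first row of input_1 with the first column of input_2, 0.53427607 * 0.03158629 + 0.69181818 * 0.66182911 + 0.30088913 * 0.12386280 ≈ 0.51201, which matches expectedOutput.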
@@ -99,27 +103,57 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
     }
-    SECTION("3D Tensor by 2D Tensor") {
-        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,1,2,3> {
+    SECTION("4D Tensors") {
+        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<float,1,2,4,3> {
             {
                 {
-                    {0.53427607, 0.69181818, 0.30088913},
-                    {0.20866227, 0.67821276, 0.25695610}
-                }
+                    {
+                        {0.78191108, 0.79929698, 0.45473319},
+                        {0.35713595, 0.45651042, 0.40217435},
+                        {0.15343380, 0.30024308, 0.78940034},
+                        {0.53266525, 0.16684306, 0.22095734}
+                    },
+                    {
+                        {0.89860427, 0.75139457, 0.34270161},
+                        {0.53609246, 0.62800729, 0.68399906},
+                        {0.57119054, 0.96259099, 0.71879345},
+                        {0.73910689, 0.62526798, 0.77325356}
+                    }
+                }
             }
         });
-        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array2D<float,3,4>{
+        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array4D<float,1,2,3,4>{
             {
-                {0.03158629, 0.21031839, 0.95692378, 0.05287921},
-                {0.66182911, 0.91662365, 0.07928377, 0.86983263},
-                {0.12386280, 0.63736272, 0.15963674, 0.465079722}
+                {
+                    {
+                        {0.36525106, 0.47606337, 0.58315367, 0.33944082},
+                        {0.56211257, 0.64100796, 0.28841895, 0.11285251},
+                        {0.04657018, 0.21112120, 0.88220179, 0.23004770}
+                    },
+                    {
+                        {0.33073467, 0.45434207, 0.92689610, 0.02250439},
+                        {0.57044137, 0.88543379, 0.23575044, 0.57311541},
+                        {0.21721125, 0.16826588, 0.45728493, 0.81760287}
+                    }
+                }
             }
         });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<float,1,2,4> {
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,2,4,4> {
             {
                 {
-                    {0.51201022, 0.93828046, 0.61414438, 0.76995558},
-                    {0.48727912, 0.82932562, 0.29446477, 0.72047055}
+                    {
+                        {0.75606567, 0.98059881, 1.08767319, 0.46022552},
+                        {0.40578386, 0.54755372, 0.69473034, 0.26526415},
+                        {0.26157477, 0.43216154, 0.87248170, 0.26756462},
+                        {0.29863116, 0.40717891, 0.55367535, 0.25046772}
+                    },
+                    {
+                        {0.80026478, 1.13124883, 1.16676664, 0.73105216},
+                        {0.68411803, 0.91472197, 0.95773751, 0.93122470},
+                        {0.89414424, 1.23277485, 1.08505893, 1.15221763},
+                        {0.76908636, 1.01955295, 1.18607962, 1.00719821}
+                    }
+                }
             }
         });
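
Here both inputs carry the batch dimensions (1,2), so nbMat1 = nbMat2 = 2 and each 4x3 matrix is multiplied with its own 3x4 counterpart (no broadcast needed). Spot check: expectedOutput[0][0][0][0] = 0.78191108 * 0.36525106 + 0.79929698 * 0.56211257 + 0.45473319 * 0.04657018 ≈ 0.75607.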