From 204f80b0d33bfff157ca724b28eecc2b83c97135 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Wed, 31 Jan 2024 16:44:39 +0100
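Subject: [PATCH] Fix MatMul CPU kernel to support batched N-D matrix shapes

The forward kernel now treats all leading dimensions of each input as a
batch of matrices and multiplies them pairwise; when the two batch counts
differ, the smaller batch is reused cyclically (batch index modulo its
count). The row count is read from the second-to-last dimension of the
first input, so that input must be at least 2-D. The "3D Tensor by 1D
Tensor" unit test is replaced accordingly, and a "4D Tensors" case is added.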

---
 .../operator/MatMulImpl_forward_kernels.hpp   | 71 ++++++++++------
 src/operator/MatMulImpl.cpp                   |  3 -
 unit_tests/operator/Test_MatMulImpl.cpp       | 84 +++++++++++++------
 3 files changed, 105 insertions(+), 53 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp
index 92bc5a61..a63e9e12 100644
--- a/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp
@@ -15,6 +15,7 @@
 #include "aidge/utils/Registrar.hpp"
 #include <algorithm>
 
+// #include <omp.h>
 #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
 
 namespace Aidge {
@@ -26,35 +27,55 @@ void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,cons
     const I* input1 = static_cast<const I*>(input1_);
     const I* input2 = static_cast<const I*>(input2_);
     O* output = static_cast<O*>(output_);
-	size_t secondToLastIdx1 = input1Dims.size() > 1 ? input1Dims.size() - 2 : 0;
+
 	size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0;
 	// Checking if matrix dimensions are compatible for multiplication
-	assert(input1Dims.back() == input2Dims[secondToLastIdx2] && "Matrix dimensions are not compatible for multiplication");
-
-    // Extracting dimensions
-    size_t rows1 = 1, cols1 = 1,  cols2 = 1;
-
-    // For input1
-    for (size_t i = 0; i < input1Dims.size() - 1; ++i) {
-        rows1 *= input1Dims[i];
-    }
-    cols1 = input1Dims.back();
+	assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] &&
+            "Matrix dimensions are not compatible for multiplication");
 
-    // For input2
-    for (size_t i = 1; i < input2Dims.size(); ++i) {
-        cols2 *= input2Dims[i];
-    }
+	std::size_t innerMulAxis = input1Dims[input1Dims.size()-1];
+	std::size_t rows1 = input1Dims[input1Dims.size()-2];
+	std::size_t cols2 = input2Dims[input2Dims.size()-1];
+	std::size_t nbMat1 = 1, nbMat2 = 1;
+	if (input1Dims.size()>2)
+	{
+		for (std::size_t i = 0; i < input1Dims.size()-2; i++)
+		{
+			nbMat1 *= input1Dims[i];
+		}
+		
+	}
+	if (input2Dims.size()>2)
+	{
+		for (std::size_t i = 0; i < input2Dims.size()-2; i++)
+		{
+			nbMat2 *= input2Dims[i];
+		}
+		
+	}
+	std::size_t mat1Size = rows1 * innerMulAxis;
+	std::size_t mat2Size = innerMulAxis * cols2;
+	std::size_t matSize = rows1 * cols2;
+	std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2;
 
-    // Multiplication
-    for (size_t i = 0; i < rows1; ++i) {
-        for (size_t j = 0; j < cols2; ++j) {
-            float sum = 0.0;
-            for (size_t k = 0; k < cols1; ++k) {
-                sum += input1[i * cols1 + k] * input2[k * cols2 + j];
-            }
-            output[i * cols2 + j] = sum;
-        }
-    }
+	for (std::size_t i = 0; i < nbMat; i++) {
+// #pragma omp parallel for num_threads(8)
+		for (std::size_t m = 0; m < rows1; m++)
+		{
+			
+			for (size_t k = 0; k < innerMulAxis; k++)
+			{
+				for (std::size_t n = 0; n < cols2; n++)
+				{
+                    if (k==0) {
+                        output[i * matSize + m * cols2 + n]  = 0;
+                    }
+                    
+					output[i * matSize + m * cols2 + n] += input1[(i%nbMat1) * mat1Size + m *innerMulAxis + k] * input2[(i%nbMat2)*mat2Size + k * cols2 + n];	
+				}
+			}	
+		}
+	}
 }
 
 namespace {
diff --git a/src/operator/MatMulImpl.cpp b/src/operator/MatMulImpl.cpp
index c1c3ccb0..5818ac64 100644
--- a/src/operator/MatMulImpl.cpp
+++ b/src/operator/MatMulImpl.cpp
@@ -32,13 +32,10 @@ void Aidge::MatMulImpl_cpu::forward()
         {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
          std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
-
     kernelFunc(
         std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
-
-
 }
diff --git a/unit_tests/operator/Test_MatMulImpl.cpp b/unit_tests/operator/Test_MatMulImpl.cpp
index ae10df27..abb9227a 100644
--- a/unit_tests/operator/Test_MatMulImpl.cpp
+++ b/unit_tests/operator/Test_MatMulImpl.cpp
@@ -59,24 +59,28 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
 
     }
 
-
-    SECTION("3D Tensor by 1D Tensor") {
-        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,2,2,3> {
+    SECTION("3D Tensor by 2D Tensor") {
+        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,1,2,3> {
             {
-                {{0.82786506, 0.19047028, 0.62954658},
-         		 {0.63160968, 0.12468684, 0.49015969}},
-
-        		{{0.49215794, 0.42231840, 0.02699018},
-        		 {0.66403216, 0.94622904, 0.42048711}}
+                {
+					{0.53427607, 0.69181818, 0.30088913},
+         		 	{0.20866227, 0.67821276, 0.25695610}
+				}
             }
         });
-        std::shared_ptr<Tensor> input_2 =  std::make_shared<Tensor>(Array1D<float,3>{
-            {0.82458717, 0.88598752, 0.78737932}
+        std::shared_ptr<Tensor> input_2 =  std::make_shared<Tensor>(Array2D<float,3,4>{
+            {
+				{0.03158629, 0.21031839, 0.95692378, 0.05287921},
+				{0.66182911, 0.91662365, 0.07928377, 0.86983263},
+				{0.12386280, 0.63736272, 0.15963674, 0.465079722}
+			}
         });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,2> {
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<float,1,2,4> {
             {
-                {1.34709311, 1.01722980},
-        		{0.80124742, 1.71698236}
+                {
+					{0.51201022, 0.93828046, 0.61414438, 0.76995558},
+         			{0.48727912, 0.82932562, 0.29446477, 0.72047055}
+				}
             }
         });
 
@@ -99,27 +103,57 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
 
     }
 
-    SECTION("3D Tensor by 2D Tensor") {
-        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array3D<float,1,2,3> {
+
+    SECTION("4D Tensors") {
+        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<float,1,2,4,3> {
             {
                 {
-					{0.53427607, 0.69181818, 0.30088913},
-         		 	{0.20866227, 0.67821276, 0.25695610}
-				}
+                    {
+                        {0.78191108, 0.79929698, 0.45473319},
+                        {0.35713595, 0.45651042, 0.40217435},
+                        {0.15343380, 0.30024308, 0.78940034},
+                        {0.53266525, 0.16684306, 0.22095734}
+                    },
+                    {
+                        {0.89860427, 0.75139457, 0.34270161},
+                        {0.53609246, 0.62800729, 0.68399906},
+                        {0.57119054, 0.96259099, 0.71879345},   
+                        {0.73910689, 0.62526798, 0.77325356}
+                    }
+                }
             }
         });
-        std::shared_ptr<Tensor> input_2 =  std::make_shared<Tensor>(Array2D<float,3,4>{
+        std::shared_ptr<Tensor> input_2 =  std::make_shared<Tensor>(Array4D<float,1,2,3,4>{
             {
-				{0.03158629, 0.21031839, 0.95692378, 0.05287921},
-				{0.66182911, 0.91662365, 0.07928377, 0.86983263},
-				{0.12386280, 0.63736272, 0.15963674, 0.465079722}
+                {
+                    {
+                        {0.36525106, 0.47606337, 0.58315367, 0.33944082},
+                        {0.56211257, 0.64100796, 0.28841895, 0.11285251},
+                        {0.04657018, 0.21112120, 0.88220179, 0.23004770}
+                    },
+                    {
+                        {0.33073467, 0.45434207, 0.92689610, 0.02250439},
+                        {0.57044137, 0.88543379, 0.23575044, 0.57311541},
+                        {0.21721125, 0.16826588, 0.45728493, 0.81760287}
+                    }
+                }
 			}
         });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<float,1,2,4> {
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,2,4,4> {
             {
                 {
-					{0.51201022, 0.93828046, 0.61414438, 0.76995558},
-         			{0.48727912, 0.82932562, 0.29446477, 0.72047055}
+					{
+                        {0.75606567, 0.98059881, 1.08767319, 0.46022552},
+                        {0.40578386, 0.54755372, 0.69473034, 0.26526415},
+                        {0.26157477, 0.43216154, 0.87248170, 0.26756462},
+                        {0.29863116, 0.40717891, 0.55367535, 0.25046772}
+                    },
+                    {
+                        {0.80026478, 1.13124883, 1.16676664, 0.73105216},
+                        {0.68411803, 0.91472197, 0.95773751, 0.93122470},
+                        {0.89414424, 1.23277485, 1.08505893, 1.15221763},
+                        {0.76908636, 1.01955295, 1.18607962, 1.00719821}
+                    }
 				}
             }
         });
-- 
GitLab