diff --git a/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp
index a63e9e124596fa529b53068fb6a589ebf42f5f55..9eae167de0ad02eef79510d0e18b6e43ce660d74 100644
--- a/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/MatMulImpl_forward_kernels.hpp
@@ -15,7 +15,6 @@
 #include "aidge/utils/Registrar.hpp"
 #include <algorithm>
 
-// #include <omp.h>
 #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
 
 namespace Aidge {
@@ -23,55 +22,43 @@ namespace Aidge {
 template <class I, class O>
 void MatMulImpl_cpu_forward_kernel(const std::vector<DimSize_t>& input1Dims,const std::vector<DimSize_t>& input2Dims,
                                    const void* input1_, const void* input2_, void* output_) {
-    // FIXME: missing MatMul parameters as arguments
     const I* input1 = static_cast<const I*>(input1_);
     const I* input2 = static_cast<const I*>(input2_);
     O* output = static_cast<O*>(output_);
 
-	size_t secondToLastIdx2 = input2Dims.size() > 1 ? input2Dims.size() - 2 : 0;
+	assert((input1Dims.size()>1 && input2Dims.size()>1) && "Inputs must be at least 2D for MatMul");
 	// Checking if matrix dimensions are compatible for multiplication
-	assert(input1Dims[input1Dims.size()-1] == input2Dims[secondToLastIdx2] &&
+	assert(input1Dims[input1Dims.size()-1] == input2Dims[input2Dims.size() - 2] &&
             "Matrix dimensions are not compatible for multiplication");
 
 	std::size_t innerMulAxis = input1Dims[input1Dims.size()-1];
 	std::size_t rows1 = input1Dims[input1Dims.size()-2];
 	std::size_t cols2 = input2Dims[input2Dims.size()-1];
+	// Compute number of 2D matrices in input1 and input2
 	std::size_t nbMat1 = 1, nbMat2 = 1;
-	if (input1Dims.size()>2)
-	{
-		for (std::size_t i = 0; i < input1Dims.size()-2; i++)
-		{
-			nbMat1 *= input1Dims[i];
-		}
-		
+	for (std::size_t i = 0; i < input1Dims.size()-2; i++) {
+		nbMat1 *= input1Dims[i];
 	}
-	if (input2Dims.size()>2)
-	{
-		for (std::size_t i = 0; i < input2Dims.size()-2; i++)
-		{
-			nbMat2 *= input2Dims[i];
-		}
-		
+	for (std::size_t i = 0; i < input2Dims.size()-2; i++) {
+		nbMat2 *= input2Dims[i];
 	}
+
 	std::size_t mat1Size = rows1 * innerMulAxis;
 	std::size_t mat2Size = innerMulAxis * cols2;
 	std::size_t matSize = rows1 * cols2;
-	std::size_t nbMat = nbMat1 > nbMat2 ? nbMat1 : nbMat2;
+	std::size_t nbMat = std::max(nbMat1, nbMat2);
 
 	for (std::size_t i = 0; i < nbMat; i++) {
-// #pragma omp parallel for num_threads(8)
-		for (std::size_t m = 0; m < rows1; m++)
-		{
-			
-			for (size_t k = 0; k < innerMulAxis; k++)
-			{
-				for (std::size_t n = 0; n < cols2; n++)
-				{
+		for (std::size_t m = 0; m < rows1; m++) {
+			for (size_t k = 0; k < innerMulAxis; k++) {
+				const std::size_t input1Idx = (i % nbMat1) * mat1Size + m * innerMulAxis + k;
+				for (std::size_t n = 0; n < cols2; n++) {
+					const std::size_t outputIdx = i * matSize + m * cols2 + n;
+					const std::size_t input2Idx = (i % nbMat2) * mat2Size + k * cols2 + n;
                     if (k==0) {
-                        output[i * matSize + m * cols2 + n]  = 0;
+                        output[outputIdx]  = 0;
                     }
-                    
-					output[i * matSize + m * cols2 + n] += input1[(i%nbMat1) * mat1Size + m *innerMulAxis + k] * input2[(i%nbMat2)*mat2Size + k * cols2 + n];	
+					output[outputIdx] += input1[input1Idx] * input2[input2Idx];
 				}
 			}	
 		}