From 8e7e8c69ecc03070b430c459eb67f8760cc7e871 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 15 Sep 2024 16:17:42 +0200
Subject: [PATCH] Basic MatMul optimization

---
 .../backend/cpu/operator/MatMulImpl_kernels.hpp    | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/MatMulImpl_kernels.hpp b/include/aidge/backend/cpu/operator/MatMulImpl_kernels.hpp
index 7cb1239e..088c89e6 100644
--- a/include/aidge/backend/cpu/operator/MatMulImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/MatMulImpl_kernels.hpp
@@ -18,19 +18,24 @@ namespace Aidge {
 
 template <class I, class O>
 void MatMulImpl_cpu_forward_kernel(const std::size_t n, const std::size_t k, const std::size_t m,
-                                    const void* input1_, const void* input2_, void* output_) {
+                                    const void* input1_, const void* input2_, void* __restrict output_) {
     // FIXME: missing MatMul parameters as arguments
     const I* input1 = static_cast<const I*>(input1_);
     const I* input2 = static_cast<const I*>(input2_);
-    O* output = static_cast<O*>(output_);
+    // __restrict (not the GCC-only __restrict__) is accepted by GCC, Clang and MSVC.
+    O* __restrict output = static_cast<O*>(output_);
 
     for (std::size_t i = 0; i < n; ++i) {
-        for (std::size_t j = 0; j < m; ++j) {
-            O sum = O(0);
-            for (std::size_t l = 0; l < k; ++l) {
-                sum += static_cast<O>(input1[i*k + l] * input2[l*m + j]);
+        // Zero the current output row with typed stores: unlike std::memset this
+        // needs no <cstring> include and does not rely on O(0) being all-bits-zero.
+        for (std::size_t j = 0; j < m; ++j) {
+            output[i*m + j] = O(0);
+        }
+        // i-l-j order: the innermost loop reads input2 and updates output at consecutive indices.
+        for (std::size_t l = 0; l < k; ++l) {
+            for (std::size_t j = 0; j < m; ++j) {
+                output[i*m + j] += static_cast<O>(input1[i*k + l] * input2[l*m + j]);
             }
-            output[i*m + j] = sum;
         }
     }
 }
-- 
GitLab