From cd558133569b535a69fbaf44715b966ce1fb0409 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 2 Apr 2024 12:06:16 +0200
Subject: [PATCH] Improved ReduceMean precision

---
 .../ReduceMeanImpl_forward_kernels.hpp        | 29 +++++++++----------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp
index d7a967e8..6533f7b1 100644
--- a/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp
@@ -47,22 +47,23 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op::Attrs& attr
             for (std::size_t post = 0; post < stride_post; ++post) {
                 const std::size_t idx_i = pre * dim_i * stride_post + post;
                 const std::size_t idx_o = pre * stride_post + post;
-                output[idx_o] = input[idx_i];
-                for (std::size_t i = 1; i < dim_i; ++i) {
-                    output[idx_o] += input[idx_i + i*stride_post];
+                O mean = 0;
+                for (std::size_t i = 0; i < dim_i; ++i) {
+                    // Single pass numerically stable mean, using the fmaf
+                    mean = fmaf(input[idx_i + i*stride_post] - mean, 1.0f/(i+1), mean);
                 }
-                output[idx_o] /= dim_i;
+                output[idx_o]  = mean;
             }
         }
     } else {
         std::size_t outputElements = totalElements;
 
-        std::size_t *stride_post = new std::size_t[nb_dims];
+        auto stride_post = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
         stride_post[nb_dims - 1] = 1;
         for (std::size_t i = nb_dims-2; i != static_cast<std::size_t>(-1); --i) {
             stride_post[i] = stride_post[i+1]*inputDims[i+1];
         }
-        std::size_t *stride_pre = new std::size_t[nb_dims];
+        auto stride_pre = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
         stride_pre[0] = 1;
         for (std::size_t i = 1; i < nb_dims; ++i) {
             stride_pre[i] = stride_pre[i-1]*inputDims[i-1];
@@ -80,13 +81,15 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op::Attrs& attr
                 for (std::size_t post = 0; post < stride_post[a]; ++post) {
                     const std::size_t idx_i = pre * dim_i * stride_post[a] + post;
                     const std::size_t idx_o = pre * stride_post[a] + post;
-                    outputAccumulation[idx_o] = inputAccumulation[idx_i];
-                    for (std::size_t i = 1; i < dim_i; ++i) {
-                        outputAccumulation[idx_o] += inputAccumulation[idx_i + i*stride_post[a]];
+                    I mean = 0;
+                    for (std::size_t i = 0; i < dim_i; ++i) {
+                        // Single pass numerically stable mean, using the fmaf
+                        mean = fmaf(inputAccumulation[idx_i + i*stride_post[a]] - mean, 1.0f/(i+1), mean);
                     }
+                    outputAccumulation[idx_o] = mean;
                 }
             }
-            std::for_each(stride_pre+a+1, stride_pre+nb_dims, [dim_i] (std::size_t& val) { val /= dim_i; });
+            std::for_each(stride_pre.get()+a+1, stride_pre.get()+nb_dims, [dim_i] (std::size_t& val) { val /= dim_i; });
             if (inputAccumulation != input) {
                 delete[] inputAccumulation;
             }
@@ -94,14 +97,10 @@ void ReduceMeanImpl_cpu_forward_kernel(const typename ReduceMean_Op::Attrs& attr
         }
 
         // Copy elements from inputAccumulation to output while dividing by divisor
-        I divisor = totalElements / outputElements;
-        std::transform(inputAccumulation, inputAccumulation + outputElements, output,
-                    [divisor](I element) { return element / divisor; });
+        std::copy(inputAccumulation, inputAccumulation + outputElements, output);
         if (outputAccumulation) {
             delete[] outputAccumulation;
         }
-        delete[] stride_post;
-        delete[] stride_pre;
     }
 }
 
-- 
GitLab