From 9eadee681c5f6ba16ab68657582fb3878e82ab37 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 6 Feb 2025 15:44:32 +0100
Subject: [PATCH] Use double instead of float to process integers

---
 .../cpu/operator/GlobalAveragePoolingImpl_kernels.hpp       | 6 +++---
 .../aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp   | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp
index 4a95c3a4..d5e5561d 100644
--- a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp
@@ -40,9 +40,9 @@ stableMean(const T* vec, size_t size) {
 template <typename T>
 typename std::enable_if<!std::is_floating_point<T>::value, T>::type
 stableMean(const T* vec, size_t size) {
-  float mean = 0;
+  double mean = 0;
   for (size_t i = 0; i < size; ++i) {
-    mean = std::fma<float>(vec[i] - mean, 1.0f / (i + 1), mean);
+    mean = std::fma<double>(vec[i] - mean, 1.0f / (i + 1), mean);
   }
   return mean;
 }
@@ -55,7 +55,7 @@ castFromFloat(T value) {
 
 template <typename T>
 typename std::enable_if<!std::is_floating_point<T>::value, T>::type
-castFromFloat(float value) {
+castFromFloat(double value) {
   return static_cast<T>(std::nearbyint(value));
 }
 
diff --git a/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp
index c0962408..864b89c4 100644
--- a/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp
@@ -40,9 +40,9 @@ stableMean(const T* vec, size_t len, size_t stride) {
 template <typename T>
 typename std::enable_if<!std::is_floating_point<T>::value, T>::type
 stableMean(const T* vec, size_t len, size_t stride) {
-  float mean = 0;
+  double mean = 0;
   for (size_t i = 0; i < len; ++i) {
-    mean = std::fma<float>(vec[i * stride] - mean, 1.0f / (i + 1), mean);
+    mean = std::fma<double>(vec[i * stride] - mean, 1.0f / (i + 1), mean);
   }
   return mean;
 }
@@ -55,7 +55,7 @@ castFromFloat(T value) {
 
 template <typename T>
 typename std::enable_if<!std::is_floating_point<T>::value, T>::type
-castFromFloat(float value) {
+castFromFloat(double value) {
   return static_cast<T>(std::nearbyint(value));
 }
 
-- 
GitLab