diff --git a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp index 4a95c3a4152dba1f19a91bcc4339c80cc90ed086..d5e5561d02aacd8532f74d2bfd4ee2fb5a5b5dc3 100644 --- a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp @@ -40,9 +40,9 @@ stableMean(const T* vec, size_t size) { template <typename T> typename std::enable_if<!std::is_floating_point<T>::value, T>::type stableMean(const T* vec, size_t size) { - float mean = 0; + double mean = 0; for (size_t i = 0; i < size; ++i) { - mean = std::fma<float>(vec[i] - mean, 1.0f / (i + 1), mean); + mean = std::fma<double>(vec[i] - mean, 1.0f / (i + 1), mean); } return mean; } @@ -55,7 +55,7 @@ castFromFloat(T value) { template <typename T> typename std::enable_if<!std::is_floating_point<T>::value, T>::type -castFromFloat(float value) { +castFromFloat(double value) { return static_cast<T>(std::nearbyint(value)); } diff --git a/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp index c0962408738c51cca7f725ef519704c845c8c7bc..864b89c4fa4667b70e43ed7436382e30bc150745 100644 --- a/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/ReduceMeanImpl_kernels.hpp @@ -40,9 +40,9 @@ stableMean(const T* vec, size_t len, size_t stride) { template <typename T> typename std::enable_if<!std::is_floating_point<T>::value, T>::type stableMean(const T* vec, size_t len, size_t stride) { - float mean = 0; + double mean = 0; for (size_t i = 0; i < len; ++i) { - mean = std::fma<float>(vec[i * stride] - mean, 1.0f / (i + 1), mean); + mean = std::fma<double>(vec[i * stride] - mean, 1.0f / (i + 1), mean); } return mean; } @@ -55,7 +55,7 @@ castFromFloat(T value) { template <typename T> typename std::enable_if<!std::is_floating_point<T>::value, T>::type -castFromFloat(float value) { +castFromFloat(double value) { return static_cast<T>(std::nearbyint(value)); }