diff --git a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp index 40dd3a6932541475e7315dbb6829f99429b030ba..cbe4f110fc74f387625132c4f0872123814c1a62 100644 --- a/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/GlobalAveragePoolingImpl_kernels.hpp @@ -43,6 +43,18 @@ static stableMean(const T* vec, std::size_t size) { return mean; } +template <typename T> +typename std::enable_if_t<std::is_floating_point<T>::value, T> +static castFromFloat(T value) { + return value; +} + +template <typename T> +typename std::enable_if_t<!std::is_floating_point<T>::value, T> +static castFromFloat(double value) { + return static_cast<T>(std::nearbyint(value)); +} + template <DataType DT_I, DataType DT_O = DT_I> void GlobalAveragePoolingImpl_cpu_forward_kernel(const std::shared_ptr<Tensor>& inputTensor, void *output_) { @@ -61,7 +73,7 @@ void GlobalAveragePoolingImpl_cpu_forward_kernel(const std::shared_ptr<Tensor>& std::size_t output_idx = 0; for (DimSize_t batch = 0; batch < dims[0]; ++batch) { for (DimSize_t channel = 0; channel < dims[1]; ++channel) { - output[output_idx++] = static_cast<O>(stableMean<I>(input + input_idx, strides_channels)); + output[output_idx++] = castFromFloat<O>(stableMean<I>(input + input_idx, strides_channels)); input_idx += strides_channels; } }