Skip to content
Snippets Groups Projects
Commit cd824d73 authored by Maxence Naud's avatar Maxence Naud
Browse files

fix: move back 'castFromFloat' function in GloabalAvgPoolingImpl_kernel as...

fix: move back 'castFromFloat' function in GloabalAvgPoolingImpl_kernel as 'static_cast' truncates and does not round as intended
parent 9fea6486
No related branches found
No related tags found
2 merge requests!166Update 0.5.0 -> 0.6.0,!152Fix compilation on MacOS
Pipeline #69517 passed
......@@ -43,6 +43,18 @@ static stableMean(const T* vec, std::size_t size) {
return mean;
}
template <typename T>
typename std::enable_if_t<std::is_floating_point<T>::value, T>
static castFromFloat(T value) {
return value;
}
template <typename T>
typename std::enable_if_t<!std::is_floating_point<T>::value, T>
static castFromFloat(double value) {
return static_cast<T>(std::nearbyint(value));
}
template <DataType DT_I, DataType DT_O = DT_I>
void GlobalAveragePoolingImpl_cpu_forward_kernel(const std::shared_ptr<Tensor>& inputTensor, void *output_) {
......@@ -61,7 +73,7 @@ void GlobalAveragePoolingImpl_cpu_forward_kernel(const std::shared_ptr<Tensor>&
std::size_t output_idx = 0;
for (DimSize_t batch = 0; batch < dims[0]; ++batch) {
for (DimSize_t channel = 0; channel < dims[1]; ++channel) {
output[output_idx++] = static_cast<O>(stableMean<I>(input + input_idx, strides_channels));
output[output_idx++] = castFromFloat<O>(stableMean<I>(input + input_idx, strides_channels));
input_idx += strides_channels;
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment