Commit a8400879 authored by Olivier BICHLER

Added output rounding

parent 2bb9f35b
2 merge requests: !135 [Upd] Patch v0.5.1, !133 Fix bug eclipse/aidge/aidge_backend_cpu#40
@@ -47,6 +47,18 @@ stableMean(const T* vec, size_t size) {
return mean;
}
+template <typename T>
+typename std::enable_if<std::is_floating_point<T>::value, T>::type
+castFromFloat(T value) {
+    return value;
+}
+
+template <typename T>
+typename std::enable_if<!std::is_floating_point<T>::value, T>::type
+castFromFloat(float value) {
+    return static_cast<T>(std::nearbyint(value));
+}
template <class I, class O>
void GlobalAveragePoolingImpl_cpu_forward_kernel(
const std::vector<DimSize_t> &dims, const void *input_, void *output_) {
@@ -71,7 +83,7 @@ void GlobalAveragePoolingImpl_cpu_forward_kernel(
for (DimSize_t channel = 0; channel < dims[1]; ++channel) {
const I *filter_start = std::next(
input, (batch * in_batch_nb_elems) + (channel * in_channel_nb_elems));
-output[batch * out_batch_nb_elems + channel] = stableMean<I>(filter_start, in_channel_nb_elems);
+output[batch * out_batch_nb_elems + channel] = castFromFloat<O>(stableMean<I>(filter_start, in_channel_nb_elems));
}
}
}
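For reference, here is a minimal standalone sketch of what the new castFromFloat helper does: floating-point outputs pass through unchanged, while integral outputs are rounded to the nearest integer by std::nearbyint (ties to even under the default rounding mode). The main() below is an illustration only and is not part of the commit.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <type_traits>

template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, T>::type
castFromFloat(T value) {
    return value; // floating-point outputs are returned as-is
}

template <typename T>
typename std::enable_if<!std::is_floating_point<T>::value, T>::type
castFromFloat(float value) {
    // round to the nearest integer instead of truncating toward zero
    return static_cast<T>(std::nearbyint(value));
}

int main() {
    std::cout << castFromFloat<float>(2.6f) << '\n';        // 2.6 (unchanged)
    std::cout << castFromFloat<std::int32_t>(2.6f) << '\n'; // 3 (rounded, not truncated to 2)
    std::cout << castFromFloat<std::int32_t>(2.5f) << '\n'; // 2 (tie rounds to even)
    return 0;
}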
@@ -47,6 +47,18 @@ stableMean(const T* vec, size_t len, size_t stride) {
return mean;
}
+template <typename T>
+typename std::enable_if<std::is_floating_point<T>::value, T>::type
+castFromFloat(T value) {
+    return value;
+}
+
+template <typename T>
+typename std::enable_if<!std::is_floating_point<T>::value, T>::type
+castFromFloat(float value) {
+    return static_cast<T>(std::nearbyint(value));
+}
template <class I, class O>
void ReduceMeanImpl_cpu_forward_kernel(const std::vector<std::int32_t>& axes,
DimSize_t /*keepDims*/,
@@ -72,7 +84,7 @@ void ReduceMeanImpl_cpu_forward_kernel(const std::vector<std::int32_t>& axes,
for (std::size_t post = 0; post < stride_post; ++post) {
const std::size_t idx_i = pre * dim_i * stride_post + post;
const std::size_t idx_o = pre * stride_post + post;
-output[idx_o] = stableMean<I>(input + idx_i, dim_i, stride_post);
+output[idx_o] = castFromFloat<O>(stableMean(input + idx_i, dim_i, stride_post));
}
}
} else {
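The rounding at the output only helps because stableMean already accumulates in floating point regardless of the input type I (see the decltype usage in the next hunk, whose comment states the return type is always floating point). The sketch below shows one plausible strided running-mean implementation with that property; the name stableMeanSketch and the body are assumptions for illustration, not the actual aidge_backend_cpu code.

#include <cstddef>

// Sketch only: incremental (running) mean over `len` elements spaced by
// `stride`, accumulating in double so integral inputs neither overflow
// nor truncate. The real stableMean<I>() in the repository may differ.
template <typename T>
double stableMeanSketch(const T* vec, std::size_t len, std::size_t stride) {
    double mean = 0.0;
    for (std::size_t i = 0; i < len; ++i) {
        // running-mean update: mean += (x - mean) / (n + 1)
        mean += (static_cast<double>(vec[i * stride]) - mean) / static_cast<double>(i + 1);
    }
    return mean;
}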
@@ -89,8 +101,9 @@ void ReduceMeanImpl_cpu_forward_kernel(const std::vector<std::int32_t>& axes,
stride_pre[i] = stride_pre[i-1]*inputDims[i-1];
}
-const I* inputAccumulation = input;
-I* outputAccumulation = nullptr;
+// Type should be the return type of stableMean<I>(), which is always floating point
+const decltype(stableMean<I>(input, 0, 0))* inputAccumulation = nullptr;
+decltype(stableMean<I>(input, 0, 0))* outputAccumulation = nullptr;
for (const auto& axisInt : axes) {
const std::size_t a = static_cast<std::size_t>(axisInt);
@@ -101,18 +114,23 @@ void ReduceMeanImpl_cpu_forward_kernel(const std::vector<std::int32_t>& axes,
for (std::size_t post = 0; post < stride_post[a]; ++post) {
const std::size_t idx_i = pre * dim_i * stride_post[a] + post;
const std::size_t idx_o = pre * stride_post[a] + post;
-outputAccumulation[idx_o] = stableMean<I>(inputAccumulation + idx_i, dim_i, stride_post[a]);
+if (inputAccumulation == nullptr) {
+    outputAccumulation[idx_o] = stableMean<I>(input + idx_i, dim_i, stride_post[a]);
+}
+else {
+    outputAccumulation[idx_o] = stableMean<I>(inputAccumulation + idx_i, dim_i, stride_post[a]);
+}
}
}
std::for_each(stride_pre.get()+a+1, stride_pre.get()+nb_dims, [dim_i] (std::size_t& val) { val /= dim_i; });
-if (inputAccumulation != input) {
+if (inputAccumulation != nullptr) {
delete[] inputAccumulation;
}
inputAccumulation = outputAccumulation;
}
-// Copy elements from inputAccumulation to output while dividing by divisor
-std::copy(inputAccumulation, inputAccumulation + outputElements, output);
+std::transform(inputAccumulation, inputAccumulation + outputElements, output,
+    [](auto value) { return castFromFloat<O>(value); });
if (outputAccumulation) {
delete[] outputAccumulation;
}
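Taken together, the change keeps intermediate reductions in floating point and only converts once at the end via castFromFloat<O>(), so integer outputs of GlobalAveragePooling and ReduceMean are rounded to nearest rather than (as presumably happened before, through the implicit float-to-int conversion on assignment) truncated toward zero. The snippet below is a hypothetical before/after illustration, not code from the repository.

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
    // Mean of {0, 1, 1} is 0.666...
    const float mean = (0 + 1 + 1) / 3.0f;

    // Before this commit: implicit conversion to the integer output type truncates.
    const std::int32_t truncated = static_cast<std::int32_t>(mean); // 0

    // After this commit: castFromFloat<O>() rounds to the nearest integer.
    const std::int32_t rounded = static_cast<std::int32_t>(std::nearbyint(mean)); // 1

    std::cout << truncated << " vs " << rounded << '\n'; // prints "0 vs 1"
    return 0;
}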