diff --git a/aidge_export_cpp/kernels/batchnorm.hpp b/aidge_export_cpp/kernels/batchnorm.hpp index 0260d42539a892150f1c13cd2a275bb579cd194a..01104f99be12b7035cf83135d956ad6855e32821 100644 --- a/aidge_export_cpp/kernels/batchnorm.hpp +++ b/aidge_export_cpp/kernels/batchnorm.hpp @@ -26,7 +26,14 @@ void batchnorm_forward ( const Rescaling_T& __restrict rescaling) { for (unsigned int output = 0; output < NB_OUTPUTS; ++output) { - const Output_T var = sqrt(variances[output] + epsilon); + // If the variance is 0, we need to avoid division by 0 + const Output_T var = epsilon; + + // If the variance is negative, we need to set it to 0 to avoid a sqrt of a negative number + if (variances[output] > 0.0) + { + var = sqrt(variances[output] + epsilon); + } for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) { for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {