From 203fee0dd7dc352f196b8e6668b69e76d2b3320a Mon Sep 17 00:00:00 2001
From: Gallasko <gallasko@gmail.com>
Date: Fri, 4 Apr 2025 15:59:30 +0200
Subject: [PATCH] fix: add batch support to the batchnorm kernel

The batchnorm kernel ignored the batch dimension and mis-indexed the
per-channel planes: the channel index was added to the spatial offset
instead of selecting a full HxW plane. Add an NB_BATCHES template
parameter, compute the complete NCHW offset, expose the batch size as
IN_BATCH/OUT_BATCH defines in the IO configuration template, and pass
OUT_BATCH to the kernel from the forward template. The export generator
must now provide in_batch/out_batch in the template context.
---
 aidge_export_cpp/kernels/batchnorm.hpp        | 32 ++++++++-----------
 .../templates/configuration/_def_io.jinja     |  2 ++
 .../kernel_forward/batchnorm_forward.jinja    |  3 +-
 3 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/aidge_export_cpp/kernels/batchnorm.hpp b/aidge_export_cpp/kernels/batchnorm.hpp
index 092ed4d..f05a047 100644
--- a/aidge_export_cpp/kernels/batchnorm.hpp
+++ b/aidge_export_cpp/kernels/batchnorm.hpp
@@ -8,7 +8,7 @@
 
 // WARNING: this kernel only works for 32-bits floating point values
 
-template<int NB_OUTPUTS,
+template<int NB_BATCHES, int NB_OUTPUTS,
          int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
          ActivationFunction_T ACTIVATION,
          typename Input_T, typename Output_T,
@@ -25,23 +25,19 @@ void batchnorm_forward (
     const double epsilon,
     const Rescaling_T& __restrict rescaling)
 {
-    for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
-        // If the variance is 0, we need to avoid division by 0
-        Output_T var = epsilon;
-
-        // If the variance is negative, we need to set it to 0 to avoid a sqrt of a negative number
-        if (variances[output] > 0.0)
-        {
-            var = sqrt(variances[output] + epsilon);
-        }
-
-        for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
-            for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
-                const int outputOffset = OUTPUTS_WIDTH * oy + ox;
-
-                const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var;
-                const Output_T sAs = scales[output] * normalized + biases[output];
-                outputs[outputOffset + output] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
+    for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) {
+        for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
+            // If the variance is not strictly positive, fall back to sqrt(epsilon) to avoid a division by 0 and the sqrt of a negative number
+            Output_T var = sqrt(variances[output] > 0.0 ? variances[output] + epsilon : epsilon);
+
+            for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
+                for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
+                    const int outputOffset = batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT + OUTPUTS_WIDTH * oy + ox;
+
+                    const Output_T normalized = (inputs[outputOffset] - means[output]) / var;
+                    const Output_T sAs = scales[output] * normalized + biases[output];
+                    outputs[outputOffset] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
+                }
             }
         }
     }
diff --git a/aidge_export_cpp/templates/configuration/_def_io.jinja b/aidge_export_cpp/templates/configuration/_def_io.jinja
index 66756cf..f444547 100644
--- a/aidge_export_cpp/templates/configuration/_def_io.jinja
+++ b/aidge_export_cpp/templates/configuration/_def_io.jinja
@@ -4,6 +4,7 @@
 #define {{ in_name[inidx]|upper }}_NB_CHANNELS {{ in_chan[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_HEIGHT {{ in_height[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_WIDTH {{ in_width[inidx] }}
+#define {{ in_name[inidx]|upper }}_IN_BATCH {{ in_batch[inidx] }}
 {% endfor %}
 
 // OUTPUT CONF
@@ -11,4 +12,5 @@
 #define {{ out_name[outidx]|upper }}_NB_OUTPUTS {{ out_chan[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_HEIGHT {{ out_height[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_WIDTH {{ out_width[outidx] }}
+#define {{ out_name[outidx]|upper }}_OUT_BATCH {{ out_batch[outidx] }}
 {% endfor %}
diff --git a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
index 05e5154..03fd8e8 100644
--- a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
+++ b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
@@ -1,6 +1,7 @@
 {% filter indent(width=4, first=False) %}
 {% include "./_mem_offset.jinja" %}
-batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS,
+batchnorm_forward<{{ out_name[0]|upper }}_OUT_BATCH,
+                  {{ out_name[0]|upper }}_NB_OUTPUTS,
                   {{ out_name[0]|upper }}_OUT_HEIGHT,
                   {{ out_name[0]|upper }}_OUT_WIDTH,
                   {{name|upper}}_ACTIVATION>
-- 
GitLab
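
Editor's note, kept below the mail signature so it is not applied with
the patch: a minimal standalone sketch of the batch-aware NCHW offset
arithmetic the patched kernel relies on. Enumerating (batch, output,
oy, ox) in NCHW order must yield consecutive offsets if the formula is
right. All sizes are made-up illustration values, not taken from an
actual export.

    // Sketch only: checks the kernel's offset formula against a dense
    // row-major enumeration of a NCHW tensor.
    #include <cassert>

    int main() {
        // Hypothetical sizes for illustration; any positive values work.
        constexpr int NB_BATCHES = 2, NB_OUTPUTS = 3;
        constexpr int OUTPUTS_HEIGHT = 4, OUTPUTS_WIDTH = 5;

        int expected = 0;
        for (int batch = 0; batch < NB_BATCHES; ++batch)
            for (int output = 0; output < NB_OUTPUTS; ++output)
                for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy)
                    for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
                        // Same formula as the patched kernel:
                        // ((batch * C + output) * H + oy) * W + ox, expanded.
                        const int outputOffset =
                            batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS
                            + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT
                            + OUTPUTS_WIDTH * oy + ox;
                        assert(outputOffset == expected++);
                    }
        return 0;
    }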