From 203fee0dd7dc352f196b8e6668b69e76d2b3320a Mon Sep 17 00:00:00 2001
From: Gallasko <gallasko@gmail.com>
Date: Fri, 4 Apr 2025 15:59:30 +0200
Subject: [PATCH] fix: Batch support for batchnorm

---
 aidge_export_cpp/kernels/batchnorm.hpp        | 32 ++++++++-----------
 .../templates/configuration/_def_io.jinja     |  2 ++
 .../kernel_forward/batchnorm_forward.jinja    |  3 +-
 3 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/aidge_export_cpp/kernels/batchnorm.hpp b/aidge_export_cpp/kernels/batchnorm.hpp
index 092ed4d..f05a047 100644
--- a/aidge_export_cpp/kernels/batchnorm.hpp
+++ b/aidge_export_cpp/kernels/batchnorm.hpp
@@ -8,7 +8,7 @@
 
 // WARNING: this kernel only works for 32-bit floating-point values
 
-template<int NB_OUTPUTS,
+template<int NB_BATCHES, int NB_OUTPUTS,
          int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
          ActivationFunction_T ACTIVATION,
          typename Input_T, typename Output_T,
@@ -25,23 +25,19 @@ void batchnorm_forward (
     const double epsilon,
     const Rescaling_T& __restrict rescaling)
 {
-    for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
-        // If the variance is 0, we need to avoid division by 0
-        Output_T var = epsilon;
-
-        // If the variance is negative, we need to set it to 0 to avoid a sqrt of a negative number
-        if (variances[output] > 0.0)
-        {
-            var = sqrt(variances[output] + epsilon);
-        }
-
-        for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
-            for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
-                const int outputOffset = OUTPUTS_WIDTH * oy + ox;
-
-                const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var;
-                const Output_T sAs = scales[output] * normalized + biases[output];
-                outputs[outputOffset + output] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
+    for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) {
+        for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
+            // If the variance is not strictly positive, use epsilon alone under the sqrt: this avoids a division by zero and the sqrt of a negative number
+            Output_T var = sqrt(variances[output] > 0.0 ? variances[output] + epsilon : epsilon);
+
+            for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
+                for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
+                    const int outputOffset = batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT + OUTPUTS_WIDTH * oy + ox;
+
+                    const Output_T normalized = (inputs[outputOffset] - means[output]) / var;
+                    const Output_T sAs = scales[output] * normalized + biases[output];
+                    outputs[outputOffset] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
+                }
             }
         }
     }
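
For reference, here is a minimal standalone sketch (not part of the patch) of the NCHW indexing the new loop nest implements. The sizes are hypothetical, chosen only for illustration; the offset arithmetic mirrors the kernel's, showing that each (batch, channel) pair owns one contiguous H*W block:

    #include <cstdio>

    // Hypothetical sizes, for illustration only.
    constexpr int NB_BATCHES = 2, NB_OUTPUTS = 3,
                  OUTPUTS_HEIGHT = 4, OUTPUTS_WIDTH = 5;

    int main() {
        // Same offsets as the kernel: batch stride is C*H*W,
        // channel stride is H*W.
        for (int batch = 0; batch < NB_BATCHES; ++batch) {
            for (int output = 0; output < NB_OUTPUTS; ++output) {
                const int first = batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS
                                + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT;
                std::printf("batch %d, channel %d -> [%d, %d)\n",
                            batch, output, first,
                            first + OUTPUTS_HEIGHT * OUTPUTS_WIDTH);
            }
        }
        return 0;
    }
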
diff --git a/aidge_export_cpp/templates/configuration/_def_io.jinja b/aidge_export_cpp/templates/configuration/_def_io.jinja
index 66756cf..f444547 100644
--- a/aidge_export_cpp/templates/configuration/_def_io.jinja
+++ b/aidge_export_cpp/templates/configuration/_def_io.jinja
@@ -4,6 +4,7 @@
 #define {{ in_name[inidx]|upper }}_NB_CHANNELS {{ in_chan[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_HEIGHT {{ in_height[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_WIDTH {{ in_width[inidx] }}
+#define {{ in_name[inidx]|upper }}_IN_BATCH {{ in_batch[inidx] }}
 {% endfor %}
 
 // OUTPUT CONF
@@ -11,4 +12,5 @@
 #define {{ out_name[outidx]|upper }}_NB_OUTPUTS {{ out_chan[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_HEIGHT {{ out_height[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_WIDTH {{ out_width[outidx] }}
+#define {{ out_name[outidx]|upper }}_OUT_BATCH {{ out_batch[outidx] }}
 {% endfor %}
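
With these two additions, the rendered configuration header gains a batch define next to the existing shape defines. For a hypothetical input tensor named input0 with N=2, C=16, H=32, W=32, the input block would expand to something like:

    #define INPUT0_NB_CHANNELS 16
    #define INPUT0_IN_HEIGHT 32
    #define INPUT0_IN_WIDTH 32
    #define INPUT0_IN_BATCH 2
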
diff --git a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
index 05e5154..03fd8e8 100644
--- a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
+++ b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
@@ -1,6 +1,7 @@
 {% filter indent(width=4, first=False) %}
 {% include "./_mem_offset.jinja" %}
-batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS,
+batchnorm_forward<{{ out_name[0]|upper }}_OUT_BATCH,
+                  {{ out_name[0]|upper }}_NB_OUTPUTS,
                   {{ out_name[0]|upper }}_OUT_HEIGHT,
                   {{ out_name[0]|upper }}_OUT_WIDTH,
                   {{name|upper}}_ACTIVATION>
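
Rendered for a node hypothetically named bn0 with an output named output0, the updated template head would expand to the call below; the argument list follows in context lines outside this hunk:

    batchnorm_forward<OUTPUT0_OUT_BATCH,
                      OUTPUT0_NB_OUTPUTS,
                      OUTPUT0_OUT_HEIGHT,
                      OUTPUT0_OUT_WIDTH,
                      BN0_ACTIVATION>
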
-- 
GitLab