From 52d701379ec57736e7380df01e1c4f080c2c98a2 Mon Sep 17 00:00:00 2001
From: Gallasko <gallasko@gmail.com>
Date: Mon, 24 Mar 2025 10:12:55 +0100
Subject: [PATCH] feat: Add BatchNorm2D export op

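Register a BatchNorm2D operator in the CPP export library for float32
inputs, and give the batchnorm kernel the same fused activation and
rescaling path as the other kernels: instead of hardcoding NoScaling
in the sat<>() call, it now takes a Rescaling_T parameter and calls
activation_forward_value from activation.hpp.

For each channel c the kernel computes

    normalized = (in - means[c]) / sqrt(variances[c] + epsilon)
    out = ACTIVATION(scales[c] * normalized + biases[c])

with the configured rescaling applied inside activation_forward_value.

As an illustration (the layer name "batchnorm0" and the buffer names
below are hypothetical), the updated templates expand to roughly:

    // batchnorm0 configuration (excerpt)
    #define BATCHNORM0_ACTIVATION Linear
    #define BATCHNORM0_EPSILON 1e-05
    static const NoScaling BATCHNORM0_RESCALING = {};

    // generated forward call
    batchnorm_forward<BATCHNORM0_NB_OUTPUTS,
                      BATCHNORM0_OUT_HEIGHT,
                      BATCHNORM0_OUT_WIDTH,
                      BATCHNORM0_ACTIVATION>
        (inputs, outputs, biases, variances, means, scales,
         BATCHNORM0_EPSILON, BATCHNORM0_RESCALING);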
---
 aidge_export_cpp/kernels/batchnorm.hpp        | 12 +++++++----
 aidge_export_cpp/operators.py                 | 21 +++++++++++++++++-
 .../configuration/batchnorm_config.jinja      |  1 +
 .../kernel_forward/batchnorm_forward.jinja    |  2 +-
 4 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/aidge_export_cpp/kernels/batchnorm.hpp b/aidge_export_cpp/kernels/batchnorm.hpp
index 740ea21..0260d42 100644
--- a/aidge_export_cpp/kernels/batchnorm.hpp
+++ b/aidge_export_cpp/kernels/batchnorm.hpp
@@ -2,7 +2,8 @@
 #define __AIDGE_EXPORT_CPP_KERNELS_BATCHNORM__
 
 #include "network/typedefs.hpp"
-#include "kernels/rescaling.hpp"
+#include "kernels/activation.hpp"
+
 #include <math.h>
 
 // WARNING: this kernel only works for 32-bits floating point values
@@ -11,7 +12,8 @@ template<int NB_OUTPUTS,
          int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
          ActivationFunction_T ACTIVATION,
          typename Input_T, typename Output_T,
-         typename Param_T>
+         typename Param_T,
+         typename Rescaling_T>
 __attribute__((always_inline)) inline
 void batchnorm_forward (
     const Input_T* __restrict inputs,
@@ -20,7 +22,8 @@ void batchnorm_forward (
     const Param_T* __restrict variances,
     const Param_T* __restrict means,
     const Param_T* __restrict scales,
-    const double epsilon)
+    const double epsilon,
+    const Rescaling_T& __restrict rescaling)
 {
     for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
         const Output_T var = sqrt(variances[output] + epsilon);
@@ -31,7 +34,8 @@
 
                 const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var;
                 const Output_T sAs = scales[output] * normalized + biases[output];
-                outputs[outputOffset + output] = sat<Output_T>(sAs, output, ACTIVATION, NoScaling);
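+                // Apply the activation and the configured rescaling to the normalized, scaled value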
+                outputs[outputOffset + output] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
             }
         }
     }
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 0f6d3c8..b7e5472 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -317,4 +317,23 @@ class SoftmaxCPP(ExportNodeCpp):
         self.kernels_to_copy = [
             str(ROOT / "kernels" / "softmax.hpp"),
             str(ROOT / "kernels" / "macs.hpp"),
-        ]
\ No newline at end of file
+        ]
+
+@ExportLibCpp.register("BatchNorm2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
+class BatchNorm2DCPP(ExportNodeCpp):
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
+        self.attributes["activation"] = "Linear"
+        self.attributes["rescaling"] = "NoScaling"
+        self.attributes["epsilon"] = node.get_operator().attr.epsilon
+        self.config_template = str(
+            ROOT / "templates" / "configuration" / "batchnorm_config.jinja")
+        self.forward_template = str(
+            ROOT / "templates" / "kernel_forward" / "batchnorm_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "batchnorm.hpp"),
+            str(ROOT / "kernels" / "macs.hpp"),
+            str(ROOT / "kernels" / "activation.hpp"),
+            str(ROOT / "kernels" / "rescaling.hpp")
+        ]
diff --git a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja
index 701ba7c..bc01e3b 100644
--- a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja
+++ b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja
@@ -7,5 +7,6 @@
 {% include "./_meminfo.jinja" %}
 #define {{ name|upper }}_ACTIVATION {{ activation }}
 #define {{ name|upper }}_EPSILON {{ epsilon }}
+static const {{ rescaling }} {{ name|upper }}_RESCALING = {};
 
 #endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
index 5a759b8..05e5154 100644
--- a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
+++ b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
@@ -4,6 +4,6 @@ batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS,
                   {{ out_name[0]|upper }}_OUT_HEIGHT,
                   {{ out_name[0]|upper }}_OUT_WIDTH,
                   {{name|upper}}_ACTIVATION>
-                  ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON);
+                  ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON, {{name|upper}}_RESCALING);
 {% include "./_save_outputs.jinja" %}
 {% endfilter %}
-- 
GitLab