diff --git a/.gitignore b/.gitignore
index 67ffbefbdc41ea1abebd64602649fb129f2faf07..93bcfd30700409d495c7a6d4eb19c12636afbda8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,9 @@ dist*/
 aidge_export_cpp/_version.py
 wheelhouse/*
 
+# Temp test folders
+aidge_export_cpp/unit_tests/*_test
+
 # Mermaid
 *.mmd
 
diff --git a/aidge_export_cpp/kernels/batchnorm.hpp b/aidge_export_cpp/kernels/batchnorm.hpp
index 740ea21e6f66ba338985db4f724a5d57377e1f81..f05a047511e12f895ef88be0e402b89e5197432b 100644
--- a/aidge_export_cpp/kernels/batchnorm.hpp
+++ b/aidge_export_cpp/kernels/batchnorm.hpp
@@ -2,16 +2,18 @@
 #define __AIDGE_EXPORT_CPP_KERNELS_BATCHNORM__
 
 #include "network/typedefs.hpp"
-#include "kernels/rescaling.hpp"
+#include "kernels/activation.hpp"
+
 #include <math.h>
 
 // WARNING: this kernel only works for 32-bits floating point values
 
-template<int NB_OUTPUTS,
+template<int NB_BATCHES, int NB_OUTPUTS,
          int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
          ActivationFunction_T ACTIVATION,
          typename Input_T, typename Output_T,
-         typename Param_T>
+         typename Param_T,
+         typename Rescaling_T>
 __attribute__((always_inline)) inline
 void batchnorm_forward (
     const Input_T* __restrict inputs,
@@ -20,18 +22,22 @@ void batchnorm_forward (
     const Param_T* __restrict variances,
     const Param_T* __restrict means,
     const Param_T* __restrict scales,
-    const double epsilon)
+    const double epsilon,
+    const Rescaling_T& __restrict rescaling)
 {
-    for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
-        const Output_T var = sqrt(variances[output] + epsilon);
+    for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) {
+        for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
+            // Guard against a non-positive variance to avoid dividing by zero
+            const Output_T var = sqrt(variances[output] > 0.0 ? variances[output] + epsilon : epsilon);
 
-        for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
-            for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
-                const int outputOffset = OUTPUTS_HEIGHT * oy + ox;
+            for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
+                for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
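+                    // NCHW indexing: offset = ((batch * NB_OUTPUTS + output) * OUTPUTS_HEIGHT + oy) * OUTPUTS_WIDTH + ox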
+                    const int outputOffset = batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT + OUTPUTS_WIDTH * oy + ox;
 
-                const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var;
-                const Output_T sAs = scales[output] * normalized + biases[output];
-                outputs[outputOffset + output] = sat<Output_T>(sAs, output, ACTIVATION, NoScaling);
+                    const Output_T normalized = (inputs[outputOffset] - means[output]) / var;
+                    const Output_T sAs = scales[output] * normalized + biases[output];
+                    outputs[outputOffset] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
+                }
             }
         }
     }
diff --git a/aidge_export_cpp/kernels/concat.hpp b/aidge_export_cpp/kernels/concat.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..dde8c4fc3a9ce9eea5d4ae4cfad35c078f60450d
--- /dev/null
+++ b/aidge_export_cpp/kernels/concat.hpp
@@ -0,0 +1,39 @@
+#ifndef __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
+#define __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
+
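+// Concatenates NB_INPUTS tensors along one axis. Each input n is viewed as
+// [AXIS_SIZE_PRE, sizes[n], AXIS_SIZE_POST]; the output's axis length is the
+// sum of sizes[n] over all inputs.
+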
+template<int AXIS_SIZE_POST,
+         int AXIS_SIZE_PRE,
+         unsigned int NB_INPUTS,
+         typename T>
+__attribute__((always_inline)) inline static
+void concat_forward (
+    const T* const * __restrict inputs,
+    const unsigned int* __restrict sizes,
+    T* __restrict output)
+{
+    unsigned int total_concat_axis_size = 0;
+    for (unsigned int n = 0; n < NB_INPUTS; ++n)
+        total_concat_axis_size += sizes[n];
+
+    for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
+        // Loop over post-axis (all dims after the concat axis)
+        for (int j = 0; j < AXIS_SIZE_POST; ++j) {
+            unsigned int axis_offset = 0;
+
+            // Loop over each input tensor
+            for (unsigned int n = 0; n < NB_INPUTS; ++n) {
+                for (unsigned int k = 0; k < sizes[n]; ++k) {
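+                    // Element (i, k, j) of input n, viewed as [AXIS_SIZE_PRE, sizes[n], AXIS_SIZE_POST]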
+                    const int input_idx  = i * sizes[n] * AXIS_SIZE_POST + k * AXIS_SIZE_POST + j;
+
+                    output[i * total_concat_axis_size * AXIS_SIZE_POST + (axis_offset + k) * AXIS_SIZE_POST + j] =
+                        inputs[n][input_idx];
+                }
+
+                axis_offset += sizes[n];  // move along axis in output
+            }
+        }
+    }
+
+}
+
+#endif  // __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
\ No newline at end of file
diff --git a/aidge_export_cpp/kernels/pad.hpp b/aidge_export_cpp/kernels/pad.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..4e83257c1152b1963dd4b0eefc912216a729de7d
--- /dev/null
+++ b/aidge_export_cpp/kernels/pad.hpp
@@ -0,0 +1,51 @@
+#ifndef __AIDGE_EXPORT_CPP_KERNELS_PAD2D__
+#define __AIDGE_EXPORT_CPP_KERNELS_PAD2D__
+
+#include "network/typedefs.hpp"
+#include "network/utils.hpp"
+
+// TODO: support border types other than Constant (Reflect, Wrap, ...)
+
+template<int NB_BATCHES, int NB_CHANNELS,
+         int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
+         int NB_OUTPUTS,
+         int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
+         int PADDING_TOP,
+         int PADDING_LEFT,
+         int PADDING_BOTTOM,
+         int PADDING_RIGHT,
+         typename Input_T, typename Output_T>
+__attribute__((always_inline)) inline
+void pad_forward(
+    double borderValue,
+    const Input_T* __restrict inputs,
+    Output_T* __restrict outputs
+    )
+{
+    const unsigned int oySize = CHANNELS_HEIGHT + PADDING_TOP + PADDING_BOTTOM;
+    const unsigned int oxSize = CHANNELS_WIDTH + PADDING_LEFT + PADDING_RIGHT;
+
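+    // Write borderValue everywhere, then overwrite the interior window with
+    // the input values (constant padding only).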
+    for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) {
+        for (unsigned int ch = 0; ch < NB_CHANNELS; ++ch) {
+            const unsigned int preIndex = batch * NB_CHANNELS * CHANNELS_HEIGHT * CHANNELS_WIDTH + ch * CHANNELS_HEIGHT * CHANNELS_WIDTH;
+
+            for (unsigned int oy = 0; oy < oySize; ++oy) {
+                for (unsigned int ox = 0; ox < oxSize; ++ox) {
+                    const unsigned int outIndex = batch * NB_CHANNELS * oySize * oxSize + ch * oySize * oxSize + oy * oxSize + ox;
+
+                    outputs[outIndex] = borderValue;
+
+                    const int inputX = static_cast<int>(ox) - PADDING_LEFT;
+                    const int inputY = static_cast<int>(oy) - PADDING_TOP;
+
+                    if (inputY >= 0 and inputY < CHANNELS_HEIGHT and inputX >= 0 and inputX < CHANNELS_WIDTH)
+                    {
+                        outputs[outIndex] = inputs[preIndex + inputY * CHANNELS_WIDTH + inputX];
+                    }
+                }
+            }
+        }
+    }
+}
+
+#endif  // __AIDGE_EXPORT_CPP_KERNELS_PAD2D__
diff --git a/aidge_export_cpp/kernels/pooling.hpp b/aidge_export_cpp/kernels/pooling.hpp
index 478b6a58aed45e2bce0ed1683ad113f9c7a8bffb..a86fd4196a9f6e19f45dbdc4f1035c1e94e7d285 100644
--- a/aidge_export_cpp/kernels/pooling.hpp
+++ b/aidge_export_cpp/kernels/pooling.hpp
@@ -7,7 +7,7 @@
 #include <stdexcept>
 
 
-template<int NB_CHANNELS, 
+template<int NB_CHANNELS,
          int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
          int NB_OUTPUTS,
          int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
@@ -17,7 +17,7 @@ template<int NB_CHANNELS,
          Pooling_T POOLING_TYPE,
          ActivationFunction_T ACTIVATION,
          typename Input_T, typename Output_T>
-__attribute__((always_inline)) inline 
+__attribute__((always_inline)) inline
 void pooling_forward(
     const Input_T* __restrict inputs,
     Output_T* __restrict outputs)
@@ -32,7 +32,7 @@ void pooling_forward(
             : max(PADDING_Y - (oy * STRIDE_Y), 0);
         const int syMax = (PADDING_Y == 0
                 && OUTPUTS_HEIGHT == OUTPUTS_HEIGHT_NOPAD) ? POOL_HEIGHT
-            : clamp(CHANNELS_HEIGHT + PADDING_Y - (oy * STRIDE_Y), 
+            : clamp(CHANNELS_HEIGHT + PADDING_Y - (oy * STRIDE_Y),
                     0, POOL_HEIGHT);
         const int iy = (oy * STRIDE_Y) - PADDING_Y;
 
@@ -45,7 +45,7 @@ void pooling_forward(
                 const int sxMax = (PADDING_X == 0
                         && OUTPUTS_WIDTH == OUTPUTS_WIDTH_NOPAD)
                             ? POOL_WIDTH
-                    : clamp(CHANNELS_WIDTH + PADDING_X - (ox * STRIDE_X), 
+                    : clamp(CHANNELS_WIDTH + PADDING_X - (ox * STRIDE_X),
                             0, POOL_WIDTH);
                 const int ix = (ox * STRIDE_X) - PADDING_X;
 
@@ -86,7 +86,7 @@ void pooling_forward(
                     outputs[oOffset + output] = maxVal;
                 }
                 else if (POOLING_TYPE == Average) {
-                    int32_t sum = 0;
+                    Output_T sum = 0;
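+                    // Accumulate in Output_T so floating-point averages are not truncated to int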
 
                     for (int sy = 0; sy < POOL_HEIGHT; ++sy) {
                         if ((PADDING_Y != 0
diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..f5472cf6d807bc2f547e58616943f6e72dccd80e
--- /dev/null
+++ b/aidge_export_cpp/kernels/softmax.hpp
@@ -0,0 +1,53 @@
+#ifndef __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
+#define __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
+
+#include "network/typedefs.hpp"
+#include "network/utils.hpp"
+#include "kernels/macs.hpp"
+
+#include <type_traits>
+#include <cmath>
+#include <algorithm>
+
+template<int AXIS_SIZE,
+         int AXIS_SIZE_POST,
+         int AXIS_SIZE_PRE,
+         typename Input_T, typename Output_T>
+__attribute__((always_inline)) inline
+void softmax_forward (
+    const Input_T* __restrict inputs,
+    Output_T* __restrict outputs)
+{
+    // Iterate over the "pre-axis" and "post-axis" slices.
+    // For each slice along the axis, compute the maximum value,
+    // the sum of exponentials, and then write the normalized softmax outputs.
+    for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
+        for (int j = 0; j < AXIS_SIZE_POST; ++j) {
+            // Compute the base index for this slice.
+            const int baseIdx = i * AXIS_SIZE * AXIS_SIZE_POST + j;
+
+            // Find the maximum value along the axis.
+            Input_T maxVal = inputs[baseIdx];
+            for (int k = 1; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                maxVal = std::max(maxVal, inputs[idx]);
+            }
+
+            // Compute the sum of the exponentials along the axis.
+            Input_T sumExp = 0;
+            for (int k = 0; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                outputs[idx] = std::exp(inputs[idx] - maxVal);
+                sumExp += outputs[idx];
+            }
+
+            // Write the softmax values to the output.
+            for (int k = 0; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                outputs[idx] /= sumExp;
+            }
+        }
+    }
+}
+
+#endif  // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 346928f4a84c403df2172311cede8b99fd06eebe..26ca62155401707573d9625ad91a9b63cb1b4d2b 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -73,10 +73,25 @@ class ProducerCPP(ExportNode):
 
 # TODO : find a way to remove this dummy exportnode
 @ExportLibCpp.register("Pad2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.any)))
-class Pad_ARMCortexM(ExportNodeCpp):
+class PadCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
-        raise NotImplementedError("Pad2D nodes is not implemented")
+        super().__init__(node, mem_info)
+        self.attributes["padding"] = node.get_operator().attr.begin_end_borders
+        self.attributes["border_type"] = node.get_operator().attr.border_type
+        self.attributes["border_value"] = node.get_operator().attr.border_value
+
+        assert self.attributes["border_type"] == aidge_core.pad_border_type.Constant, (
+            f"export Pad2d: border_type == {node.get_operator().attr.border_type} not implemented"
+        )
 
+        self.config_template = str(
+            ROOT / "templates" / "configuration" / "pad_config.jinja")
+        self.forward_template = str(
+            ROOT / "templates" / "kernel_forward" / "pad_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "pad.hpp")
+        ]
 
 @ExportLibCpp.register("ReLU", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class ReLUCPP(ExportNodeCpp):
@@ -237,6 +252,20 @@ class MaxPoolCPP(ExportNodeCpp):
 
         _setup_pooling(self)
 
+@ExportLibCpp.register("AvgPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
+class AvgPoolCPP(ExportNodeCpp):
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
+
+        # No padding with AvgPooling
+        # Use PaddedAvgPooling to add a padding attribute
+        self.attributes["padding"] = [0, 0]
+        self.attributes["pool_type"] = "Average"
+        self.attributes["activation"] = "Linear"
+        self.attributes["rescaling"] = "NoScaling"
+
+        _setup_pooling(self)
+
 @ExportLibCpp.register_metaop("PaddedMaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class PaddedMaxPoolCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
@@ -302,4 +331,117 @@ class TransposeCPP(ExportNodeCpp):
         self.include_list = []
         self.kernels_to_copy = [
             str(ROOT / "kernels" / "transpose.hpp")
+        ]
+
+@ExportLibCpp.register("Softmax", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
+class SoftmaxCPP(ExportNodeCpp):
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
+        assert self.node.get_nb_inputs() == 1, (
+            f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented"
+        )
+
+        tensor = self.operator.get_input(0)
+        nbDims = len(tensor.dims())
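+        # Convert a negative axis (counted from the end) to its positive equivalent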
+        axis = node.get_operator().attr.axis
+        if axis < 0:
+            axis += nbDims
+
+        assert axis < nbDims, (
+            f"export softmax: attribute axis == {node.get_operator().attr.axis} should be less than {nbDims}"
+        )
+
+        postAxisElems = 1
+        for i in range(axis + 1, nbDims):
+            postAxisElems *= tensor.dims()[i]
+
+        preAxisElems = 1
+        for i in range(axis):
+            preAxisElems *= tensor.dims()[i]
+
+        self.attributes["axis_size"] = tensor.dims()[axis]
+        self.attributes["axis_size_post"] = postAxisElems
+        self.attributes["axis_size_pre"] = preAxisElems
+
+        self.config_template = str(
+            ROOT / "templates" / "configuration" / "softmax_config.jinja")
+        self.forward_template = str(
+            ROOT / "templates" / "kernel_forward" / "softmax_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "softmax.hpp"),
+            str(ROOT / "kernels" / "macs.hpp"),
+        ]
+
+@ExportLibCpp.register("BatchNorm2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
+class BatchNorm2DCPP(ExportNodeCpp):
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
+        self.attributes["activation"] = "Linear"
+        self.attributes["rescaling"] = "NoScaling"
+        self.attributes["epsilon"] = node.get_operator().attr.epsilon
+        self.config_template = str(
+            ROOT / "templates" / "configuration" / "batchnorm_config.jinja")
+        self.forward_template = str(
+            ROOT / "templates" / "kernel_forward" / "batchnorm_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "batchnorm.hpp"),
+            str(ROOT / "kernels" / "macs.hpp"),
+            str(ROOT / "kernels" / "activation.hpp"),
+            str(ROOT / "kernels" / "rescaling.hpp")
+        ]
+
+@ExportLibCpp.register("Concat", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
+class ConcatCPP(ExportNodeCpp):
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
+        assert self.node.get_nb_inputs() >= 1, (
+            f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented"
+        )
+
+        for idx, _ in enumerate(self.node.inputs()):
+            if self.operator.get_input(idx) is not None:
+                tensor = self.operator.get_input(idx)
+                nbDims = len(tensor.dims())
+                axis = node.get_operator().attr.axis
+                if axis < 0:
+                    axis += nbDims
+
+                assert axis < nbDims, (
+                    f"export softmax: attribute axis == {axis} should be less than {nbDims}"
+                )
+
+                postAxisElems = 1
+                for i in range(axis + 1, nbDims):
+                    postAxisElems *= tensor.dims()[i]
+
+                preAxisElems = 1
+                for i in range(axis):
+                    preAxisElems *= tensor.dims()[i]
+
+                if idx == 0:
+                    self.attributes["axis_size_post"] = postAxisElems
+                    self.attributes["axis_size_pre"] = preAxisElems
+
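+                    # One slot per input; filled below with each input's size along the axis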
+                    self.attributes["axis_size"] = [None] * self.attributes["nb_in"]
+                else:
+                    assert self.attributes["axis_size_post"] == postAxisElems, (
+                        f"export concat: axis_size_post {self.attributes['axis_size_post']} != {postAxisElems}"
+                    )
+                    assert self.attributes["axis_size_pre"] == preAxisElems, (
+                        f"export concat: axis_size_pre {self.attributes['axis_size_pre']} != {preAxisElems}"
+                    )
+
+                self.attributes["axis_size"][idx] = tensor.dims()[axis]
+            else:
+                assert False, (
+                    f"export concat: input {idx} is None, not implemented")
+
+        self.config_template = str(ROOT / "templates" / "configuration" / "concat_config.jinja")
+        self.forward_template = str(ROOT / "templates" / "kernel_forward" / "concat_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "concat.hpp"),
         ]
\ No newline at end of file
diff --git a/aidge_export_cpp/templates/configuration/_def_io.jinja b/aidge_export_cpp/templates/configuration/_def_io.jinja
index 66756cf8f501035f7222272f9c410908f499f06f..f44454769bc66e5d15e93834b28e088525930271 100644
--- a/aidge_export_cpp/templates/configuration/_def_io.jinja
+++ b/aidge_export_cpp/templates/configuration/_def_io.jinja
@@ -4,6 +4,7 @@
 #define {{ in_name[inidx]|upper }}_NB_CHANNELS {{ in_chan[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_HEIGHT {{ in_height[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_WIDTH {{ in_width[inidx] }}
+#define {{ in_name[inidx]|upper }}_IN_BATCH {{ in_batch[inidx] }}
 {% endfor %}
 
 // OUTPUT CONF
@@ -11,4 +12,5 @@
 #define {{ out_name[outidx]|upper }}_NB_OUTPUTS {{ out_chan[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_HEIGHT {{ out_height[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_WIDTH {{ out_width[outidx] }}
+#define {{ out_name[outidx]|upper }}_OUT_BATCH {{ out_batch[outidx] }}
 {% endfor %}
diff --git a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja
index 701ba7c46e4727eca86fcabf3ed997cab69f4e92..ae7ef5760a63689d11f6d7369e387b55b7cb3d15 100644
--- a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja
+++ b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja
@@ -1,11 +1,13 @@
 {#- For name header -#}
 #ifndef {{ name|upper }}_LAYER_H
 #define {{ name|upper }}_LAYER_H
+#include "kernels/rescaling.hpp"
 
 {# For layer configuration -#}
 {% include "./_def_io.jinja" %}
 {% include "./_meminfo.jinja" %}
 #define {{ name|upper }}_ACTIVATION {{ activation }}
 #define {{ name|upper }}_EPSILON {{ epsilon }}
+static const {{ rescaling }} {{ name|upper }}_RESCALING = {};
 
 #endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/configuration/concat_config.jinja b/aidge_export_cpp/templates/configuration/concat_config.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..ea8246db9a315a371e0cacea5d45d07fa2b8f7e8
--- /dev/null
+++ b/aidge_export_cpp/templates/configuration/concat_config.jinja
@@ -0,0 +1,18 @@
+{#- For name header -#}
+#ifndef {{ name|upper }}_LAYER_H
+#define {{ name|upper }}_LAYER_H
+
+{% include "./_def_io.jinja" %}
+{% include "./_meminfo.jinja" %}
+
+// Attributes
+#define {{ name|upper }}_NB_INPUTS {{ nb_in }}
+#define {{ name|upper }}_AXIS {{ axis }}
+{%- for i in range(nb_in) %}
+#define {{ name|upper }}_INPUT_{{i}}_SIZE {{ axis_size[i] }}
+{%- endfor %}
+
+#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
+#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}
+
+#endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/configuration/pad_config.jinja b/aidge_export_cpp/templates/configuration/pad_config.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..8b21577fe4d6f52ddb36ae796740f265db3d45cc
--- /dev/null
+++ b/aidge_export_cpp/templates/configuration/pad_config.jinja
@@ -0,0 +1,13 @@
+{#- For name header -#}
+#ifndef {{ name|upper }}_LAYER_H
+#define {{ name|upper }}_LAYER_H
+{# For layer configuration -#}
+{% include "./_def_io.jinja" %}
+{% include "./_meminfo.jinja" %}
+#define {{ name|upper }}_PADDING_BOTTOM {{ padding[2] }}
+#define {{ name|upper }}_PADDING_RIGHT {{ padding[3] }}
+#define {{ name|upper }}_PADDING_TOP {{ padding[0] }}
+#define {{ name|upper }}_PADDING_LEFT {{ padding[1] }}
+#define {{ name|upper }}_BORDER_VALUE {{ border_value }}
+
+#endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/configuration/softmax_config.jinja b/aidge_export_cpp/templates/configuration/softmax_config.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..e9661bc553bfefb5a0fb12be5fe87106ac90e4a9
--- /dev/null
+++ b/aidge_export_cpp/templates/configuration/softmax_config.jinja
@@ -0,0 +1,14 @@
+{#- For name header -#}
+#ifndef {{ name|upper }}_LAYER_H
+#define {{ name|upper }}_LAYER_H
+{# For layer configuration -#}
+{% include "./_def_io.jinja" %}
+{% include "./_meminfo.jinja" %}
+
+#define {{ name|upper }}_AXIS_SIZE {{ axis_size }}
+#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
+#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}
+
+#endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
index 5a759b839cd0b04b3b82f8ca4cb8dd1b0201f4f7..03fd8e89921bfa27f4eeb33b05a47b40329fa5de 100644
--- a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
+++ b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja
@@ -1,9 +1,10 @@
 {% filter indent(width=4, first=False) %}
 {% include "./_mem_offset.jinja" %}
-batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS,
+batchnorm_forward<{{ out_name[0]|upper }}_OUT_BATCH,
+                  {{ out_name[0]|upper }}_NB_OUTPUTS,
                   {{ out_name[0]|upper }}_OUT_HEIGHT,
                   {{ out_name[0]|upper }}_OUT_WIDTH,
                   {{name|upper}}_ACTIVATION>
-                  ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON);
+                  ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON, {{name|upper}}_RESCALING);
 {% include "./_save_outputs.jinja" %}
 {% endfilter %}
diff --git a/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja b/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..7a77e904db6c18f338f93099f4f117c9285bf6fc
--- /dev/null
+++ b/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja
@@ -0,0 +1,22 @@
+{% filter indent(width=4, first=False) %}
+{% include "./_mem_offset.jinja" %}
+const float* {{ name|upper }}_INPUTS[] = {
+    {%- for i in range(nb_in) -%}
+        {{ in_name[i] }}{{ ", " if not loop.last else "" }}
+    {%- endfor -%}
+};
+
+unsigned int {{ name|upper }}_SIZES[] = {
+    {%- for i in range(nb_in) -%}
+        {{ name|upper }}_INPUT_{{i}}_SIZE{{ ", " if not loop.last else "" }}
+    {%- endfor -%}
+};
+
+concat_forward<{{ name|upper }}_AXIS_SIZE_POST,
+               {{ name|upper }}_AXIS_SIZE_PRE,
+               {{ nb_in }},
+               float> (
+    {{ name|upper }}_INPUTS,
+    {{ name|upper }}_SIZES,
+    {{ out_name[0] }});
+{% include "./_save_outputs.jinja" %}
+{% endfilter %}
diff --git a/aidge_export_cpp/templates/kernel_forward/pad_forward.jinja b/aidge_export_cpp/templates/kernel_forward/pad_forward.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..721418709f589d56723156797d7e45afe1259a7b
--- /dev/null
+++ b/aidge_export_cpp/templates/kernel_forward/pad_forward.jinja
@@ -0,0 +1,16 @@
+{% filter indent(width=4, first=False) %}
+{% include "./_mem_offset.jinja" %}
+pad_forward<{{ in_name[0]|upper }}_IN_BATCH,
+            {{ in_name[0]|upper }}_NB_CHANNELS,
+            {{ in_name[0]|upper }}_IN_HEIGHT,
+            {{ in_name[0]|upper }}_IN_WIDTH,
+            {{ out_name[0]|upper }}_NB_OUTPUTS,
+            {{ out_name[0]|upper }}_OUT_HEIGHT,
+            {{ out_name[0]|upper }}_OUT_WIDTH,
+            {{name|upper}}_PADDING_TOP,
+            {{name|upper}}_PADDING_LEFT,
+            {{name|upper}}_PADDING_BOTTOM,
+            {{name|upper}}_PADDING_RIGHT>
+            ({{name|upper}}_BORDER_VALUE, {{in_name[0]}}, {{out_name[0]}});
+{% include "./_save_outputs.jinja" %}
+{% endfilter %}
diff --git a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..7c8e067f34bb2167544bab017e6b581345ba8bb2
--- /dev/null
+++ b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
@@ -0,0 +1,8 @@
+{% filter indent(width=4, first=False) %}
+{% include "./_mem_offset.jinja" %}
+softmax_forward<{{ name|upper }}_AXIS_SIZE,
+                {{ name|upper }}_AXIS_SIZE_POST,
+                {{ name|upper }}_AXIS_SIZE_PRE>
+                ({{in_name[0]}}, {{out_name[0]}});
+{% include "./_save_outputs.jinja" %}
+{% endfilter %}
diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py
index d900df83285f9b43a098b00d5b853391e7f97f92..607778d23deda862db73f5908fd1caa6ccc1d95b 100644
--- a/aidge_export_cpp/unit_tests/test_export.py
+++ b/aidge_export_cpp/unit_tests/test_export.py
@@ -3,9 +3,12 @@ import aidge_core
 import aidge_backend_cpu
 import aidge_export_cpp
 import numpy as np
+import operator
+from functools import reduce
 
 import subprocess
 import re
+import shutil
 from aidge_core.utils import run_command
 
 def initFiller(model):
@@ -32,6 +35,32 @@ def initFiller(model):
             else:
                 pass
 
+def _np_init(shape, dtype=np.float32):
+    """
+    Generates a NumPy array with the given shape, filled with random values between -1 and 1
+    with a step of 0.1.
+
+    :param shape: Tuple of dimensions for the array
+    :param dtype: Data type of the output array (default: np.float32)
+    :return: A NumPy array with the given shape and dtype
+    """
+    total_elements = reduce(operator.mul, shape, 1)
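+    # Integers in [0, 20] shifted to [-10, 10], then scaled to [-1.0, 1.0] in steps of 0.1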
+    data = (np.random.randint(0, 21, size=total_elements) - 10) / 10.0
+    return data.reshape(shape).astype(dtype)
+
+def _np_init_ones(shape, default_value=0.01, dtype=np.float32):
+    """
+    Generates a NumPy array with the given shape, filled with a single constant value.
+
+    :param shape: Tuple of dimensions for the array
+    :param default_value: Constant fill value (default: 0.01)
+    :param dtype: Data type of the output array (default: np.float32)
+    :return: A NumPy array with the given shape and dtype
+    """
+    total_elements = reduce(operator.mul, shape, 1)
+    data = np.ones(total_elements) * default_value
+    return data.reshape(shape).astype(dtype)
+
 
 class test_operator_export(unittest.TestCase):
 
@@ -43,7 +72,7 @@ class test_operator_export(unittest.TestCase):
     def tearDown(self):
         pass
 
-    def unit_test_export(self, graph_view, op_name, in_dims):
+    def unit_test_export(self, graph_view, op_name, in_dims, random_inputs=True, random_weights=True, default_value=0.01):
         """
         TODO:
         * Handle multiple dataformat
@@ -56,14 +85,34 @@ class test_operator_export(unittest.TestCase):
         4- Retrieve standard output and using regex to now if the results are the same
         """
         graph_view.compile("cpu", aidge_core.dtype.float32, dims=in_dims)
+
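+        # Fill every Producer (weights, biases, ...) with either random values or a constant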
+        for node in graph_view.get_nodes():
+            if node.type() == "Producer":
+                prod_op = node.get_operator()
+                value = prod_op.get_output(0)
+
+                if random_weights:
+                    tensor = aidge_core.Tensor(_np_init(value.dims()))
+                    prod_op.set_output(0, tensor)
+                else:
+                    aidge_core.constant_filler(value, default_value)
+
         scheduler = aidge_core.SequentialScheduler(graph_view)
 
-        in_tensor = [aidge_core.Tensor(np.random.random(in_dim).astype(np.float32)) for in_dim in in_dims]
+        if random_inputs:
+            in_tensor = [aidge_core.Tensor(_np_init(in_dim)) for in_dim in in_dims]
+        else:
+            in_tensor = [aidge_core.Tensor(_np_init_ones(in_dim, default_value)) for in_dim in in_dims]
+
         scheduler.forward(data=in_tensor)
 
         # Note the convention ``<op_name>_test`` is useful for gitignore to avoid pushing generated export by accident.
         export_folder = op_name + "_test"
 
+        shutil.rmtree(export_folder, ignore_errors=True)
+
         # Export the model in C++ standalone
         aidge_core.export_utils.scheduler_export(
                 scheduler,
@@ -112,6 +161,46 @@ class test_operator_export(unittest.TestCase):
 
         self.unit_test_export(model, "FC_flat", [[1, 6, 1, 1]])
 
+    def test_export_softmax(self):
+        print("Softmax")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=1, name="sf0")
+        ])
+
+        self.unit_test_export(model, "Softmax", [[1, 10]])
+
+    def test_export_softmax_batch(self):
+        print("SoftmaxBatch")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=1, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxBatch", [[3, 10]])
+
+    def test_export_softmax_axis_2(self):
+        print("SoftmaxAxis2")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=2, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]])
+
+    def test_export_softmax_axis_negative(self):
+        print("SoftmaxAxisNegative")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=-3, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxisNegative", [[1, 10, 3, 7]])
+
+    def test_export_softmax_axis_0(self):
+        print("SoftmaxAxis0")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=0, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis0", [[10]])
+
     @unittest.skip("Currently this test is failing")
     def test_export_FC_image_in(self):
         """Test exporting a FC operator with a HWC input.
@@ -122,6 +211,347 @@ class test_operator_export(unittest.TestCase):
         initFiller(model)
         self.unit_test_export(model, "FC_img", [[1, 3, 2, 2]])
 
+    def test_export_relu(self):
+        print("ReLU")
+        model = aidge_core.sequential([
+            aidge_core.ReLU(name="relu0")
+        ])
+
+        self.unit_test_export(model, "ReLU", [[1, 10]])
+
+    def test_export_add(self):
+        print("Add")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 5], name="producer"),
+            aidge_core.Add(name="add")
+        ])
+
+        self.unit_test_export(model, "Add", [[1, 5, 5]])
+
+    def test_export_add_larger(self):
+        print("AddLarger")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 7, 5], name="producer"),
+            aidge_core.Add(name="add")
+        ])
+
+        self.unit_test_export(model, "Add", [[1, 7, 5]])
+
+    def test_export_add_higher(self):
+        print("AddHigher")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 7], name="producer"),
+            aidge_core.Add(name="add")
+        ])
+
+        self.unit_test_export(model, "Add", [[1, 5, 7]])
+
+    # "Broadcast not supported yet in export operator"
+    @unittest.expectedFailure
+    def test_export_add_simple_broadcast(self):
+        print("AddSimpleBroadcast")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 1, 5], name="producer"),
+            aidge_core.Add(name="add")
+        ])
+
+        self.unit_test_export(model, "AddSimpleBroadcast", [[1, 7, 5]])
+
+    # "Broadcast not supported yet in export operator"
+    @unittest.expectedFailure
+    def test_export_add_double_broadcast(self):
+        print("AddDoubleBroadcast")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 1, 7], name="producer"),
+            aidge_core.Add(name="add")
+        ])
+
+        self.unit_test_export(model, "AddDoubleBroadcast", [[1, 5, 1]])
+
+    def test_export_sub(self):
+        print("Sub")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 5], name="producer"),
+            aidge_core.Sub(name="sub")
+        ])
+
+        self.unit_test_export(model, "Sub", [[1, 5, 5]])
+
+    def test_export_sub_larger(self):
+        print("SubLarger")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 7, 5], name="producer"),
+            aidge_core.Sub(name="sub")
+        ])
+
+        self.unit_test_export(model, "Sub", [[1, 7, 5]])
+
+    def test_export_sub_higher(self):
+        print("SubHigher")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 7], name="producer"),
+            aidge_core.Sub(name="sub")
+        ])
+
+        self.unit_test_export(model, "Sub", [[1, 5, 7]])
+
+    # "Broadcast not supported yet in export operator"
+    @unittest.expectedFailure
+    def test_export_sub_simple_broadcast(self):
+        print("SubSimpleBroadcast")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 1, 5], name="producer"),
+            aidge_core.Sub(name="sub")
+        ])
+
+        self.unit_test_export(model, "SubSimpleBroadcast", [[1, 7, 5]])
+
+    # "Broadcast not supported yet in export operator"
+    @unittest.expectedFailure
+    def test_export_sub_double_broadcast(self):
+        print("SubDoubleBroadcast")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 1, 7], name="producer"),
+            aidge_core.Sub(name="sub")
+        ])
+
+        self.unit_test_export(model, "SubDoubleBroadcast", [[1, 5, 1]])
+
+    def test_export_mul(self):
+        print("Mul")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 5], name="producer"),
+            aidge_core.Mul(name="mul")
+        ])
+
+        self.unit_test_export(model, "Mul", [[1, 5, 5]])
+
+    def test_export_mul_larger(self):
+        print("MulLarger")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 7, 5], name="producer"),
+            aidge_core.Mul(name="mul")
+        ])
+
+        self.unit_test_export(model, "Mul", [[1, 7, 5]])
+
+    def test_export_mul_higher(self):
+        print("MulHigher")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 7], name="producer"),
+            aidge_core.Mul(name="mul")
+        ])
+
+        self.unit_test_export(model, "Mul", [[1, 5, 7]])
+
+    # "Broadcast not supported yet in export operator"
+    @unittest.expectedFailure
+    def test_export_mul_simple_broadcast(self):
+        print("MulSimpleBroadcast")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 1, 5], name="producer"),
+            aidge_core.Mul(name="mul")
+        ])
+
+        self.unit_test_export(model, "MulSimpleBroadcast", [[1, 7, 5]])
+
+    # "Broadcast not supported yet in export operator"
+    @unittest.expectedFailure
+    def test_export_mul_double_broadcast(self):
+        print("MulDoubleBroadcast")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 1, 7], name="producer"),
+            aidge_core.Mul(name="mul")
+        ])
+
+        self.unit_test_export(model, "MulDoubleBroadcast", [[1, 5, 1]])
+
+    def test_export_mul_batch(self):
+        print("MulBatch")
+        model = aidge_core.sequential([
+            aidge_core.Producer([3, 5, 7], name="producer"),
+            aidge_core.Mul(name="mul")
+        ])
+
+        self.unit_test_export(model, "MulBatch", [[3, 5, 7]])
+
+    def test_export_concat(self):
+        print("Concat")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 7], name="producer"),
+            aidge_core.Concat(nb_inputs=2, axis=1, name="concat")
+        ])
+
+        self.unit_test_export(model, "Concat", [[1, 5, 7]])
+
+    def test_export_concat_axis_2(self):
+        print("ConcatAxis2")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 7], name="producer"),
+            aidge_core.Concat(nb_inputs=2, axis=2, name="concat")
+        ])
+
+        self.unit_test_export(model, "ConcatAxis2", [[1, 5, 7]])
+
+    def test_export_concat_axis_negative(self):
+        print("ConcatAxisNegative")
+        model = aidge_core.sequential([
+            aidge_core.Producer([1, 5, 7], name="producer"),
+            aidge_core.Concat(nb_inputs=2, axis=-2, name="concat")
+        ])
+
+        self.unit_test_export(model, "ConcatAxisNegative", [[1, 5, 7]])
+
+    def test_export_conv2D(self):
+        print("Conv2D")
+        model = aidge_core.sequential([
+            aidge_core.Conv2D(in_channels=3, out_channels=3, kernel_dims=(3, 3), name="conv")
+        ])
+
+        self.unit_test_export(model, "Conv2D", [[1, 3, 12, 12]], False, False)
+
+    def test_export_max_pooling(self):
+        print("MaxPooling2D")
+        model = aidge_core.sequential([
+            aidge_core.MaxPooling2D(kernel_dims=(3, 3), name="max_pool")
+        ])
+
+        self.unit_test_export(model, "MaxPooling2D", [[1, 2, 12, 12]], False, False)
+
+    def test_export_avg_pooling(self):
+        print("AvgPooling2D")
+        model = aidge_core.sequential([
+            aidge_core.AvgPooling2D(kernel_dims=(3, 3), name="avg_pool")
+        ])
+
+        self.unit_test_export(model, "AvgPooling2D", [[1, 2, 12, 12]], False, False)
+
+    def test_export_pad2D(self):
+        print("Pad2D")
+        model = aidge_core.sequential([
+            aidge_core.Pad2D((1, 1, 1, 1), name="pad2d")
+        ])
+
+        self.unit_test_export(model, "Pad2D", [[1, 1, 11, 11]])
+
+    def test_export_pad2D_larger(self):
+        print("Pad2DLarger")
+        model = aidge_core.sequential([
+            aidge_core.Pad2D((1, 3, 1, 3), name="pad2d")
+        ])
+
+        self.unit_test_export(model, "Pad2DLarger", [[1, 1, 7, 11]])
+
+    def test_export_pad2D_higher(self):
+        print("Pad2DHigher")
+        model = aidge_core.sequential([
+            aidge_core.Pad2D((3, 1, 3, 1), name="pad2d")
+        ])
+
+        self.unit_test_export(model, "Pad2DHigher", [[1, 1, 11, 7]])
+
+    def test_export_pad2D_mismatch(self):
+        print("Pad2DMismatch")
+        model = aidge_core.sequential([
+            aidge_core.Pad2D((1, 3, 5, 7), name="pad2d")
+        ])
+
+        self.unit_test_export(model, "Pad2DMismatch", [[3, 5, 11, 7]])
+
+    def test_export_pad2D_denser(self):
+        print("Pad2DDenser")
+        model = aidge_core.sequential([
+            aidge_core.Pad2D((3, 3, 3, 3), name="pad2d")
+        ])
+
+        self.unit_test_export(model, "Pad2DDenser", [[1, 5, 7, 11]])
+
+    def test_export_pad2D_with_bigger_batch_size(self):
+        print("Pad2DBiggerBatchSize")
+        model = aidge_core.sequential([
+            aidge_core.Pad2D((1, 1, 1, 1), name="pad2d")
+        ])
+
+        self.unit_test_export(model, "Pad2DBiggerBatchSize", [[3, 5, 7, 11]])
+
+    @unittest.expectedFailure
+    def test_export_pad2D_not_constant(self):
+        print("Pad2DNotConstant")
+        model = aidge_core.sequential([
+            aidge_core.Pad2D((3, 3, 3, 3), border_type=aidge_core.pad_border_type.Wrap, name="pad2d")
+        ])
+
+        self.unit_test_export(model, "Pad2DNotConstant", [[1, 5, 7, 11]])
+
+    def test_export_batchnorm2D(self):
+        print("BatchNormalization2D")
+        model = aidge_core.sequential([
+            aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
+        ])
+
+        self.unit_test_export(model, "BatchNorm2D", [[1, 1, 5, 5]], False, False)
+
+    def test_export_batchnorm2D_Larger(self):
+        print("BatchNormalization2DLarger")
+        model = aidge_core.sequential([
+            aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
+        ])
+
+        self.unit_test_export(model, "BatchNorm2DLarger", [[1, 1, 5, 7]], False, False)
+
+    def test_export_batchnorm2D_Higher(self):
+        print("BatchNormalization2DHigher")
+        model = aidge_core.sequential([
+            aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
+        ])
+
+        self.unit_test_export(model, "BatchNorm2DHigher", [[1, 1, 7, 5]], False, False)
+
+    def test_export_batchnorm2D_Denser(self):
+        print("BatchNormalization2DDenser")
+        model = aidge_core.sequential([
+            aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
+        ])
+
+        self.unit_test_export(model, "BatchNorm2DDenser", [[1, 3, 5, 7]], False, False)
+
+    def test_export_batchnorm2D_with_bigger_batch_size(self):
+        print("BatchNormalization2DBiggerBatchSize")
+        model = aidge_core.sequential([
+            aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
+        ])
+
+        self.unit_test_export(model, "BatchNormalization2DBiggerBatchSize", [[4, 3, 5, 7]], False, False)
+
+    def test_export_cpp(self):
+        print("Export test to do")
+
     def test_export_Conv(self):
         model = aidge_core.sequential([
             aidge_core.Conv2D(1, 1, [3, 3], name="InputNode")