diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp
index 73d00da058ac53c7c625ae66d65a9aead19559a4..f5472cf6d807bc2f547e58616943f6e72dccd80e 100644
--- a/aidge_export_cpp/kernels/softmax.hpp
+++ b/aidge_export_cpp/kernels/softmax.hpp
@@ -6,50 +6,49 @@
 #include "kernels/macs.hpp"
 
 #include <type_traits>
-
 #include <cmath>
+#include <algorithm>
 
-template<int NB_CHANNELS,
-         int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
-         int NB_OUTPUTS,
-         int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
-         int AXIS,
+template<int AXIS_SIZE,
+         int AXIS_SIZE_POST,
+         int AXIS_SIZE_PRE,
          typename Input_T, typename Output_T>
 __attribute__((always_inline)) inline
 void softmax_forward (
     const Input_T* __restrict inputs,
     Output_T* __restrict outputs)
 {
-    Input_T maxValue = 0.0f;
-
-    for (int och = 0; och < NB_OUTPUTS; och++) {
-        maxValue = std::max(maxValue, inputs[och]);
-    }
-
-    Input_T sumExp = 0.0f;
-
-    if constexpr (std::is_same_v<Input_T, Output_T>) {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            // This should be both more performant while keeping the same memory footprint but we can only use it if INPUT_T and OUTPUT_T types are the same !
-            outputs[och] = std::exp(inputs[och] - maxValue);
-            sumExp += outputs[och];
-        }
-
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] /= sumExp;
-        }
-    }
-    else
-    {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            sumExp += std::exp(inputs[och] - maxValue);
-        }
-
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] = std::exp(inputs[och] - maxValue) / sumExp;
+    // The input is viewed as (AXIS_SIZE_PRE, AXIS_SIZE, AXIS_SIZE_POST) in row-major order.
+    // For each (pre, post) index pair, compute the maximum value along the axis,
+    // the sum of exponentials, and then write the normalized softmax outputs.
+    for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
+        for (int j = 0; j < AXIS_SIZE_POST; ++j) {
+            // Compute the base index for this slice.
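+            // Elements along the axis are spaced AXIS_SIZE_POST apart in memory.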
+            const int baseIdx = i * AXIS_SIZE * AXIS_SIZE_POST + j;
+
+            // Find the maximum value along the axis.
+            Input_T maxVal = inputs[baseIdx];
+            for (int k = 1; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                maxVal = std::max(maxVal, inputs[idx]);
+            }
+
+            // Compute exp(x - max) for each element, store it in outputs, and accumulate the sum.
+            Input_T sumExp = 0;
+            for (int k = 0; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                outputs[idx] = std::exp(inputs[idx] - maxVal);
+                sumExp += outputs[idx];
+            }
+
+            // Normalize in place by the sum to obtain the softmax values.
+            for (int k = 0; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                outputs[idx] /= sumExp;
+            }
         }
     }
 }
 
-
 #endif  // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index a6ad95d26c0d7649611a6940646c7b4cd72364e5..7c22cdb7af01392e0e3deb05bf4aecc16565e6a9 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -338,6 +338,34 @@ class SoftmaxCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
         self.attributes["axis"] = node.get_operator().attr.axis
+
+        assert self.node.get_nb_inputs() == 1, (
+            f"export softmax: only a single input is supported, got {self.node.get_nb_inputs()}"
+        )
+
+        tensor = self.operator.get_input(0)
+        nbDims = len(tensor.dims())
+
+        assert self.attributes["axis"] < nbDims, (
+            f"export softmax: attribute axis == {node.get_operator().attr.axis} should be less than {nbDims}"
+        )
+
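+        # Collapse the dimensions after the softmax axis into a single "post"
+        # size and the dimensions before it into a "pre" size, so the kernel
+        # can address the input as a (pre, axis, post) view.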
+        postAxisElems = 1
+        for i in range(self.attributes["axis"] + 1, nbDims):
+            postAxisElems *= tensor.dims()[i]
+
+        preAxisElems = 1
+        for i in range(self.attributes["axis"]):
+            preAxisElems *= tensor.dims()[i]
+
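+        # These sizes become the AXIS_SIZE, AXIS_SIZE_POST and AXIS_SIZE_PRE kernel template parameters.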
+        self.attributes["axis_size"] = tensor.dims()[self.attributes["axis"]]
+        self.attributes["axis_size_post"] = postAxisElems
+        self.attributes["axis_size_pre"] = preAxisElems
+
         self.config_template = str(
             ROOT / "templates" / "configuration" / "softmax_config.jinja")
         self.forward_template = str(
diff --git a/aidge_export_cpp/templates/configuration/softmax_config.jinja b/aidge_export_cpp/templates/configuration/softmax_config.jinja
index d8ec8af05d07bad5cc67fa6ec092400e42d6a9df..e9661bc553bfefb5a0fb12be5fe87106ac90e4a9 100644
--- a/aidge_export_cpp/templates/configuration/softmax_config.jinja
+++ b/aidge_export_cpp/templates/configuration/softmax_config.jinja
@@ -7,6 +7,8 @@
 
 {#- Calculate sizes #}
 {%- set weights_size = out_chan[0] * in_chan[0] * in_height[0] * in_width[0] %}
-#define {{ name|upper }}_AXIS {{ axis }}
+#define {{ name|upper }}_AXIS_SIZE {{ axis_size }}
+#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
+#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}
 
 #endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
index 607ad53005ba5be9a0857e64f7196ca9b1be7c06..7c8e067f34bb2167544bab017e6b581345ba8bb2 100644
--- a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
+++ b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
@@ -1,12 +1,8 @@
 {% filter indent(width=4, first=False) %}
 {% include "./_mem_offset.jinja" %}
-softmax_forward<{{ in_name[0]|upper }}_NB_CHANNELS,
-                       {{ in_name[0]|upper }}_IN_HEIGHT,
-                       {{ in_name[0]|upper }}_IN_WIDTH,
-                       {{ out_name[0]|upper }}_NB_OUTPUTS,
-                       {{ out_name[0]|upper }}_OUT_HEIGHT,
-                       {{ out_name[0]|upper }}_OUT_WIDTH,
-                       {{ name|upper }}_AXIS>
-                       ({{in_name[0]}}, {{out_name[0]}});
+softmax_forward<{{ name|upper }}_AXIS_SIZE,
+                {{ name|upper }}_AXIS_SIZE_POST,
+                {{ name|upper }}_AXIS_SIZE_PRE>
+                ({{in_name[0]}}, {{out_name[0]}});
 {% include "./_save_outputs.jinja" %}
 {% endfilter %}
diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py
index d8e7814e204c91534f53571bec6c84ada4f590e8..3d55f114f81c65ac14022eb4d9395aecb92285fc 100644
--- a/aidge_export_cpp/unit_tests/test_export.py
+++ b/aidge_export_cpp/unit_tests/test_export.py
@@ -169,6 +169,30 @@ class test_operator_export(unittest.TestCase):
 
         self.unit_test_export(model, "Softmax", [[1, 10]])
 
+    def test_export_softmax_batch(self):
+        print("SoftmaxBatch")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=1, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxBatch", [[3, 10]])
+
+    def test_export_softmax_axis_2(self):
+        print("SoftmaxAxis2")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=2, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]])
+
+    def test_export_softmax_axis_0(self):
+        print("SoftmaxAxis0")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=0, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis0", [[10]])
+
     @unittest.skip("Currently this test is failing")
     def test_export_FC_image_in(self):
         """Test exporting a FC operator with a HWC input.