diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp
index 73d00da058ac53c7c625ae66d65a9aead19559a4..5e2444a0b10d874c64bbb384555278b0208617f9 100644
--- a/aidge_export_cpp/kernels/softmax.hpp
+++ b/aidge_export_cpp/kernels/softmax.hpp
@@ -6,8 +6,8 @@
 #include "kernels/macs.hpp"
 
 #include <type_traits>
-
 #include <cmath>
+#include <algorithm>
 
 template<int NB_CHANNELS,
          int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
@@ -20,36 +20,56 @@ void softmax_forward (
     const Input_T* __restrict inputs,
     Output_T* __restrict outputs)
 {
-    Input_T maxValue = 0.0f;
+    // TODO: this dims calculation cannot work in general, as this operator can take an arbitrary
+    // number of dimensions. It currently only works for axis 0 and 1; handling the general case
+    // requires passing the dims and their count as arguments to the function.
+    constexpr int nDims = 3;
+    constexpr int dims[3] = {NB_CHANNELS, NB_OUTPUTS, CHANNELS_WIDTH};
+    constexpr int axisIdx = (AXIS < 0) ? AXIS + nDims : AXIS;
+    constexpr int preAxisElems =
+          (axisIdx == 0) ? 1
+        : (axisIdx == 1) ? 1
+        : (axisIdx == 2) ? NB_CHANNELS * CHANNELS_HEIGHT
+        : 0;  // Should not occur if axisIdx is valid
+    constexpr int axisSize = dims[axisIdx];
+    constexpr int postAxisElems =
+          (axisIdx == 2) ? 1
+        : (axisIdx == 1) ? 1
+        : (axisIdx == 0) ? CHANNELS_HEIGHT * CHANNELS_WIDTH
+        : 0;  // Should not occur if axisIdx is valid
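+    // For example, with AXIS == 1 this gives preAxisElems = 1, axisSize = NB_OUTPUTS
+    // and postAxisElems = 1, i.e. a single softmax over the NB_OUTPUTS values.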
 
-    for (int och = 0; och < NB_OUTPUTS; och++) {
-        maxValue = std::max(maxValue, inputs[och]);
-    }
 
-    Input_T sumExp = 0.0f;
 
-    if constexpr (std::is_same_v<Input_T, Output_T>) {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            // This should be both more performant while keeping the same memory footprint but we can only use it if INPUT_T and OUTPUT_T types are the same !
-            outputs[och] = std::exp(inputs[och] - maxValue);
-            sumExp += outputs[och];
-        }
+    // Iterate over the "pre-axis" and "post-axis" slices.
+    // For each slice along the axis, compute the maximum value,
+    // the sum of exponentials, and then write the normalized softmax outputs.
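+    // An element (i, k, j) of the (pre-axis, axis, post-axis) decomposition lives at
+    // flat index i * axisSize * postAxisElems + k * postAxisElems + j.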
+    for (int i = 0; i < preAxisElems; ++i) {
+        for (int j = 0; j < postAxisElems; ++j) {
+            // Compute the base index for this slice.
+            const int baseIdx = i * axisSize * postAxisElems + j;
 
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] /= sumExp;
-        }
-    }
-    else
-    {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            sumExp += std::exp(inputs[och] - maxValue);
-        }
+            // Find the maximum value along the axis.
+            Input_T maxVal = inputs[baseIdx];
+            for (int k = 1; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                maxVal = std::max(maxVal, inputs[idx]);
+            }
 
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] = std::exp(inputs[och] - maxValue) / sumExp;
+            // Compute the sum of the exponentials along the axis.
+            Input_T sumExp = 0;
+            for (int k = 0; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                outputs[idx] = std::exp(inputs[idx] - maxVal);
+                sumExp += outputs[idx];
+            }
+
+            // Write the softmax values to the output.
+            for (int k = 0; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                outputs[idx] /= sumExp;
+            }
         }
     }
 }
 
-
 #endif  // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 90d039fe73c93a380fa85baa939a61b207e48fba..2fdb82bc91f8ca502fee9107edc7b6641bf9a33d 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -334,6 +334,11 @@ class SoftmaxCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
         self.attributes["axis"] = node.get_operator().attr.axis
+
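+        # Note: the softmax.hpp kernel rebuilds the input dims from its template
+        # parameters and currently only supports axis 0 and 1 (see the TODO there).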
+        assert self.attributes["axis"] in (0, 1), (
+            f"export softmax: attribute axis == {node.get_operator().attr.axis} not implemented"
+        )
+
         self.config_template = str(
             ROOT / "templates" / "configuration" / "softmax_config.jinja")
         self.forward_template = str(
diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py
index 4f9ea7a175ef238c8103081c92e78dcd8e663066..51eaf7bd59f91b21b8d94842843157d6acf1c5b9 100644
--- a/aidge_export_cpp/unit_tests/test_export.py
+++ b/aidge_export_cpp/unit_tests/test_export.py
@@ -166,6 +166,23 @@ class test_operator_export(unittest.TestCase):
 
         self.unit_test_export(model, "Softmax", [[1, 10]])
 
+    @unittest.expectedFailure
+    def test_export_softmax_axis_2(self):
+        print("SoftmaxAxis2")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=2, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]])
+
+    def test_export_softmax_axis_0(self):
+        print("SoftmaxAxis0")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=0, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis0", [[10]])
+
     @unittest.skip("Currently this test is failing")
     def test_export_FC_image_in(self):
         """Test exporting a FC operator with a HWC input.