From 7c3813bbfa2987fc711cfe5f73ed9dfe3325e98b Mon Sep 17 00:00:00 2001
From: Gallasko <gallasko@gmail.com>
Date: Fri, 4 Apr 2025 16:50:31 +0200
Subject: [PATCH] feat: Softmax works with any number of dimensions

---
 aidge_export_cpp/kernels/softmax.hpp          | 66 +++++++++----------
 aidge_export_cpp/operators.py                 | 24 +++++++
 .../configuration/softmax_config.jinja        |  4 +-
 .../kernel_forward/softmax_forward.jinja      | 12 ++--
 aidge_export_cpp/unit_tests/test_export.py    | 24 +++++++
 5 files changed, 87 insertions(+), 43 deletions(-)

diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp
index 73d00da..f5472cf 100644
--- a/aidge_export_cpp/kernels/softmax.hpp
+++ b/aidge_export_cpp/kernels/softmax.hpp
@@ -6,50 +6,48 @@
 #include "kernels/macs.hpp"
 
 #include <type_traits>
-
 #include <cmath>
+#include <algorithm>
 
-template<int NB_CHANNELS,
-         int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
-         int NB_OUTPUTS,
-         int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
-         int AXIS,
+template<int AXIS_SIZE,
+         int AXIS_SIZE_POST,
+         int AXIS_SIZE_PRE,
          typename Input_T, typename Output_T>
 __attribute__((always_inline)) inline
 void softmax_forward (
     const Input_T* __restrict inputs,
     Output_T* __restrict outputs)
 {
-    Input_T maxValue = 0.0f;
-
-    for (int och = 0; och < NB_OUTPUTS; och++) {
-        maxValue = std::max(maxValue, inputs[och]);
-    }
-
-    Input_T sumExp = 0.0f;
-
-    if constexpr (std::is_same_v<Input_T, Output_T>) {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            // This should be both more performant while keeping the same memory footprint but we can only use it if INPUT_T and OUTPUT_T types are the same !
-            outputs[och] = std::exp(inputs[och] - maxValue);
-            sumExp += outputs[och];
-        }
-
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] /= sumExp;
-        }
-    }
-    else
-    {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            sumExp += std::exp(inputs[och] - maxValue);
-        }
-
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] = std::exp(inputs[och] - maxValue) / sumExp;
+    // Iterate over every (pre-axis, post-axis) position. For each one,
+    // compute the maximum value along the axis, then the sum of exponentials,
+    // and finally write the normalized softmax values to the output.
+    for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
+        for (int j = 0; j < AXIS_SIZE_POST; ++j) {
+            // Compute the base index for this slice.
+            const int baseIdx = i * AXIS_SIZE * AXIS_SIZE_POST + j;
+
+            // Find the maximum value along the axis.
+            Input_T maxVal = inputs[baseIdx];
+            for (int k = 1; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                maxVal = std::max(maxVal, inputs[idx]);
+            }
+
+            // Compute the sum of the exponentials along the axis.
+            Input_T sumExp = 0;
+            for (int k = 0; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                outputs[idx] = std::exp(inputs[idx] - maxVal);
+                sumExp += outputs[idx];
+            }
+
+            // Write the softmax values to the output.
+            for (int k = 0; k < AXIS_SIZE; ++k) {
+                const int idx = baseIdx + k * AXIS_SIZE_POST;
+                outputs[idx] /= sumExp;
+            }
         }
     }
 }
 
-
 #endif  // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
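
For reference, the reworked kernel maps a tensor of any rank onto a (pre, axis, post) decomposition. The standalone sketch below is not part of the patch: it assumes softmax.hpp and its "kernels/macs.hpp" dependency are on the include path, and uses an arbitrary {2, 3, 4} float input with axis = 1, i.e. AXIS_SIZE = 3, AXIS_SIZE_POST = 4, AXIS_SIZE_PRE = 2. It then checks that every run along the softmax axis sums to 1.

// Standalone sketch (not part of the patch): exercises the new kernel
// for a {2, 3, 4} tensor with axis = 1.
#include "kernels/softmax.hpp"

#include <cstdio>

int main() {
    float in[2 * 3 * 4];
    float out[2 * 3 * 4];
    for (int i = 0; i < 2 * 3 * 4; ++i) {
        in[i] = static_cast<float>(i % 5);  // arbitrary test values
    }

    // Template parameter order is <AXIS_SIZE, AXIS_SIZE_POST, AXIS_SIZE_PRE>.
    softmax_forward<3, 4, 2>(in, out);

    // Every run along the axis should now sum to 1.
    // The kernel's index formula is idx = (i * AXIS_SIZE + k) * AXIS_SIZE_POST + j.
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 4; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < 3; ++k) {
                sum += out[(i * 3 + k) * 4 + j];
            }
            std::printf("pre=%d post=%d sum=%f\n", i, j, sum);
        }
    }
    return 0;
}
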
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index a6ad95d..7c22cdb 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -338,6 +338,30 @@ class SoftmaxCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
         self.attributes["axis"] = node.get_operator().attr.axis
+
+        assert self.node.get_nb_inputs() == 1, (
+            f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented"
+        )
+
+        tensor = self.operator.get_input(0)
+        nbDims = len(tensor.dims())
+
+        assert self.attributes["axis"] < nbDims, (
+            f"export softmax: attribute axis == {node.get_operator().attr.axis} should be less than {nbDims}"
+        )
+
+        postAxisElems = 1
+        for i in range(self.attributes["axis"] + 1, nbDims):
+            postAxisElems *= tensor.dims()[i]
+
+        preAxisElems = 1
+        for i in range(self.attributes["axis"]):
+            preAxisElems *= tensor.dims()[i]
+
+        self.attributes["axis_size"] = tensor.dims()[self.attributes["axis"]]
+        self.attributes["axis_size_post"] = postAxisElems
+        self.attributes["axis_size_pre"] = preAxisElems
+
         self.config_template = str(
             ROOT / "templates" / "configuration" / "softmax_config.jinja")
         self.forward_template = str(
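
The pre/post flattening added above is what lets the export handle any rank. As a hedged illustration (a standalone sketch, not part of the export code), the same computation written in C++ for dims {1, 10, 3, 7} and axis = 2, the shape used by the new SoftmaxAxis2 test, yields AXIS_SIZE_PRE = 10, AXIS_SIZE = 3, AXIS_SIZE_POST = 7, the values that end up in the generated configuration header.

// Standalone sketch: the dims -> (pre, size, post) flattening, worked for
// the SoftmaxAxis2 test case (dims = {1, 10, 3, 7}, axis = 2).
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> dims = {1, 10, 3, 7};
    const int axis = 2;

    // Product of the dimensions before the axis...
    int pre = 1;
    for (int i = 0; i < axis; ++i) {
        pre *= dims[i];
    }
    // ...and of the dimensions after it.
    int post = 1;
    for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) {
        post *= dims[i];
    }

    // Prints: AXIS_SIZE_PRE=10 AXIS_SIZE=3 AXIS_SIZE_POST=7
    std::printf("AXIS_SIZE_PRE=%d AXIS_SIZE=%d AXIS_SIZE_POST=%d\n",
                pre, dims[axis], post);
    return 0;
}
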
diff --git a/aidge_export_cpp/templates/configuration/softmax_config.jinja b/aidge_export_cpp/templates/configuration/softmax_config.jinja
index d8ec8af..e9661bc 100644
--- a/aidge_export_cpp/templates/configuration/softmax_config.jinja
+++ b/aidge_export_cpp/templates/configuration/softmax_config.jinja
@@ -7,6 +7,8 @@
 
 {#- Calculate sizes #}
 {%- set weights_size = out_chan[0] * in_chan[0] * in_height[0] * in_width[0] %}
-#define {{ name|upper }}_AXIS {{ axis }}
+#define {{ name|upper }}_AXIS_SIZE {{ axis_size }}
+#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
+#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}
 
 #endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
index 607ad53..7c8e067 100644
--- a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
+++ b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
@@ -1,12 +1,8 @@
 {% filter indent(width=4, first=False) %}
 {% include "./_mem_offset.jinja" %}
-softmax_forward<{{ in_name[0]|upper }}_NB_CHANNELS,
-                       {{ in_name[0]|upper }}_IN_HEIGHT,
-                       {{ in_name[0]|upper }}_IN_WIDTH,
-                       {{ out_name[0]|upper }}_NB_OUTPUTS,
-                       {{ out_name[0]|upper }}_OUT_HEIGHT,
-                       {{ out_name[0]|upper }}_OUT_WIDTH,
-                       {{ name|upper }}_AXIS>
-                       ({{in_name[0]}}, {{out_name[0]}});
+softmax_forward<{{ name|upper }}_AXIS_SIZE,
+                {{ name|upper }}_AXIS_SIZE_POST,
+                {{ name|upper }}_AXIS_SIZE_PRE>
+                ({{in_name[0]}}, {{out_name[0]}});
 {% include "./_save_outputs.jinja" %}
 {% endfilter %}
diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py
index d8e7814..3d55f11 100644
--- a/aidge_export_cpp/unit_tests/test_export.py
+++ b/aidge_export_cpp/unit_tests/test_export.py
@@ -169,6 +169,30 @@ class test_operator_export(unittest.TestCase):
 
         self.unit_test_export(model, "Softmax", [[1, 10]])
 
+    def test_export_softmax_batch(self):
+        print("SoftmaxBatch")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=1, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxBatch", [[3, 10]])
+
+    def test_export_softmax_axis_2(self):
+        print("SoftmaxAxis2")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=2, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]])
+
+    def test_export_softmax_axis_0(self):
+        print("SoftmaxAxis0")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=0, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis0", [[10]])
+
     @unittest.skip("Currently this test is failing")
     def test_export_FC_image_in(self):
         """Test exporting a FC operator with a HWC input.
-- 
GitLab