From 504b2db05978d2c6c9716b1059f7ada6dfc2cb88 Mon Sep 17 00:00:00 2001
From: Gallasko <gallasko@gmail.com>
Date: Thu, 6 Mar 2025 11:35:46 +0100
Subject: [PATCH] feat: Added Softmax export op

---
 .gitignore                                    |  3 +
 aidge_export_cpp/kernels/softmax.hpp          | 55 +++++++++++++++++++
 aidge_export_cpp/operators.py                 | 15 +++++
 .../configuration/softmax_config.jinja        | 12 ++++
 .../kernel_forward/softmax_forward.jinja      | 12 ++++
 aidge_export_cpp/unit_tests/test_export.py    |  8 +++
 6 files changed, 105 insertions(+)
 create mode 100644 aidge_export_cpp/kernels/softmax.hpp
 create mode 100644 aidge_export_cpp/templates/configuration/softmax_config.jinja
 create mode 100644 aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja

diff --git a/.gitignore b/.gitignore
index 67ffbef..93bcfd3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,9 @@ dist*/
 aidge_export_cpp/_version.py
 wheelhouse/*
 
+# Temp test folders
+aidge_export_cpp/unit_tests/*_temp_test
+
 # Mermaid
 *.mmd
 
diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp
new file mode 100644
index 0000000..73d00da
--- /dev/null
+++ b/aidge_export_cpp/kernels/softmax.hpp
@@ -0,0 +1,62 @@
+#ifndef __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
+#define __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
+
+#include "network/typedefs.hpp"
+#include "network/utils.hpp"
+#include "kernels/macs.hpp"
+
+#include <type_traits>
+#include <cmath>
+
+// Softmax over the NB_OUTPUTS dimension:
+//   outputs[i] = exp(inputs[i] - max) / sum_j exp(inputs[j] - max)
+// NOTE(review): the AXIS template parameter is currently unused -- the kernel
+// always normalizes over the flattened NB_OUTPUTS dimension, so only inputs
+// whose softmax axis maps onto that dimension are handled correctly.
+template<int NB_CHANNELS,
+         int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
+         int NB_OUTPUTS,
+         int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
+         int AXIS,
+         typename Input_T, typename Output_T>
+__attribute__((always_inline)) inline
+void softmax_forward (
+    const Input_T* __restrict inputs,
+    Output_T* __restrict outputs)
+{
+    // Seed the running maximum from the first element (not from 0) so the
+    // exp() arguments are properly shifted even when every input is
+    // negative, avoiding needless underflow. Assumes NB_OUTPUTS >= 1.
+    Input_T maxValue = inputs[0];
+
+    for (int och = 1; och < NB_OUTPUTS; och++) {
+        maxValue = std::max(maxValue, inputs[och]);
+    }
+
+    Input_T sumExp = 0.0f;
+
+    if constexpr (std::is_same_v<Input_T, Output_T>) {
+        // Same input/output type: stage the exponentials in the output
+        // buffer so std::exp() is evaluated only once per element.
+        for (int och = 0; och < NB_OUTPUTS; och++) {
+            outputs[och] = std::exp(inputs[och] - maxValue);
+            sumExp += outputs[och];
+        }
+
+        for (int och = 0; och < NB_OUTPUTS; och++) {
+            outputs[och] /= sumExp;
+        }
+    } else {
+        // Different types: accumulate the normalizer first, then write the
+        // converted result (costs a second std::exp() pass per element).
+        for (int och = 0; och < NB_OUTPUTS; och++) {
+            sumExp += std::exp(inputs[och] - maxValue);
+        }
+
+        for (int och = 0; och < NB_OUTPUTS; och++) {
+            outputs[och] = std::exp(inputs[och] - maxValue) / sumExp;
+        }
+    }
+}
+
+#endif  // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 346928f..0f6d3c8 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -302,4 +302,19 @@ class TransposeCPP(ExportNodeCpp):
         self.include_list = []
         self.kernels_to_copy = [
             str(ROOT / "kernels" / "transpose.hpp")
+        ]
+
+@ExportLibCpp.register("Softmax", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
+class SoftmaxCPP(ExportNodeCpp):
+    def __init__(self, node, mem_info):
+        super().__init__(node, mem_info)
+        self.attributes["axis"] = node.get_operator().attr.axis
+        self.config_template = str(
+            ROOT / "templates" / "configuration" / "softmax_config.jinja")
+        self.forward_template = str(
+            ROOT / "templates" / "kernel_forward" / "softmax_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "softmax.hpp"),
+            str(ROOT / "kernels" / "macs.hpp"),
         ]
\ No newline at end of file
diff --git a/aidge_export_cpp/templates/configuration/softmax_config.jinja b/aidge_export_cpp/templates/configuration/softmax_config.jinja
new file mode 100644
index 0000000..d8ec8af
--- /dev/null
+++ b/aidge_export_cpp/templates/configuration/softmax_config.jinja
@@ -0,0 +1,10 @@
+{#- For name header -#}
+#ifndef {{ name|upper }}_LAYER_H
+#define {{ name|upper }}_LAYER_H
+{# For layer configuration -#}
+{% include "./_def_io.jinja" %}
+{% include "./_meminfo.jinja" %}
+
+#define {{ name|upper }}_AXIS {{ axis }}
+
+#endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
new file mode 100644
index 0000000..607ad53
--- /dev/null
+++ b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja
@@ -0,0 +1,12 @@
+{% filter indent(width=4, first=False) %}
+{% include "./_mem_offset.jinja" %}
+softmax_forward<{{ in_name[0]|upper }}_NB_CHANNELS,
+                {{ in_name[0]|upper }}_IN_HEIGHT,
+                {{ in_name[0]|upper }}_IN_WIDTH,
+                {{ out_name[0]|upper }}_NB_OUTPUTS,
+                {{ out_name[0]|upper }}_OUT_HEIGHT,
+                {{ out_name[0]|upper }}_OUT_WIDTH,
+                {{ name|upper }}_AXIS>
+                ({{in_name[0]}}, {{out_name[0]}});
+{% include "./_save_outputs.jinja" %}
+{% endfilter %}
diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py
index d900df8..27280fe 100644
--- a/aidge_export_cpp/unit_tests/test_export.py
+++ b/aidge_export_cpp/unit_tests/test_export.py
@@ -112,6 +112,14 @@ class test_operator_export(unittest.TestCase):
 
         self.unit_test_export(model, "FC_flat", [[1, 6, 1, 1]])
 
+    def test_export_softmax(self):
+        print("Softmax")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=1, name="sf0")
+        ])
+
+        self.assertTrue(unit_test_export(model, "Softmax", [[1, 10]]))
+
     @unittest.skip("Currently this test is failing")
     def test_export_FC_image_in(self):
         """Test exporting a FC operator with a HWC input.
-- 
GitLab