From 7c3813bbfa2987fc711cfe5f73ed9dfe3325e98b Mon Sep 17 00:00:00 2001 From: Gallasko <gallasko@gmail.com> Date: Fri, 4 Apr 2025 16:50:31 +0200 Subject: [PATCH] feat: Softmax works with any number of dimensions --- aidge_export_cpp/kernels/softmax.hpp | 66 +++++++++---------- aidge_export_cpp/operators.py | 24 +++++++ .../configuration/softmax_config.jinja | 4 +- .../kernel_forward/softmax_forward.jinja | 12 ++-- aidge_export_cpp/unit_tests/test_export.py | 24 +++++++ 5 files changed, 87 insertions(+), 43 deletions(-) diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp index 73d00da..f5472cf 100644 --- a/aidge_export_cpp/kernels/softmax.hpp +++ b/aidge_export_cpp/kernels/softmax.hpp @@ -6,50 +6,48 @@ #include "kernels/macs.hpp" #include <type_traits> - #include <cmath> +#include <algorithm> -template<int NB_CHANNELS, - int CHANNELS_HEIGHT, int CHANNELS_WIDTH, - int NB_OUTPUTS, - int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH, - int AXIS, +template<int AXIS_SIZE, + int AXIS_SIZE_POST, + int AXIS_SIZE_PRE, typename Input_T, typename Output_T> __attribute__((always_inline)) inline void softmax_forward ( const Input_T* __restrict inputs, Output_T* __restrict outputs) { - Input_T maxValue = 0.0f; - - for (int och = 0; och < NB_OUTPUTS; och++) { - maxValue = std::max(maxValue, inputs[och]); - } - - Input_T sumExp = 0.0f; - - if constexpr (std::is_same_v<Input_T, Output_T>) { - for (int och = 0; och < NB_OUTPUTS; och++) { - // This should be both more performant while keeping the same memory footprint but we can only use it if INPUT_T and OUTPUT_T types are the same ! - outputs[och] = std::exp(inputs[och] - maxValue); - sumExp += outputs[och]; - } - - for (int och = 0; och < NB_OUTPUTS; och++) { - outputs[och] /= sumExp; - } - } - else - { - for (int och = 0; och < NB_OUTPUTS; och++) { - sumExp += std::exp(inputs[och] - maxValue); - } - - for (int och = 0; och < NB_OUTPUTS; och++) { - outputs[och] = std::exp(inputs[och] - maxValue) / sumExp; + // Iterate over the "pre-axis" and "post-axis" slices. + // For each slice along the axis, compute the maximum value, + // the sum of exponentials, and then write the normalized softmax outputs. + for (int i = 0; i < AXIS_SIZE_PRE; ++i) { + for (int j = 0; j < AXIS_SIZE_POST; ++j) { + // Compute the base index for this slice. + const int baseIdx = i * AXIS_SIZE * AXIS_SIZE_POST + j; + + // Find the maximum value along the axis. + Input_T maxVal = inputs[baseIdx]; + for (int k = 1; k < AXIS_SIZE; ++k) { + const int idx = baseIdx + k * AXIS_SIZE_POST; + maxVal = std::max(maxVal, inputs[idx]); + } + + // Compute the sum of the exponentials along the axis. + Input_T sumExp = 0; + for (int k = 0; k < AXIS_SIZE; ++k) { + const int idx = baseIdx + k * AXIS_SIZE_POST; + outputs[idx] = std::exp(inputs[idx] - maxVal); + sumExp += outputs[idx]; + } + + // Write the softmax values to the output. 
+ for (int k = 0; k < AXIS_SIZE; ++k) { + const int idx = baseIdx + k * AXIS_SIZE_POST; + outputs[idx] /= sumExp; + } } } } - #endif // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__ diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index a6ad95d..7c22cdb 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -338,6 +338,30 @@ class SoftmaxCPP(ExportNodeCpp): def __init__(self, node, mem_info): super().__init__(node, mem_info) self.attributes["axis"] = node.get_operator().attr.axis + + assert self.node.get_nb_inputs() == 1, ( + f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented" + ) + + tensor = self.operator.get_input(0) + nbDims = len(tensor.dims()) + + assert self.attributes["axis"] < nbDims, ( + f"export softmax: attribute axis == {node.get_operator().attr.axis} should be less than {nbDims}" + ) + + postAxisElems = 1 + for i in range(self.attributes["axis"] + 1, nbDims): + postAxisElems *= tensor.dims()[i] + + preAxisElems = 1 + for i in range(self.attributes["axis"]): + preAxisElems *= tensor.dims()[i] + + self.attributes["axis_size"] = tensor.dims()[self.attributes["axis"]] + self.attributes["axis_size_post"] = postAxisElems + self.attributes["axis_size_pre"] = preAxisElems + self.config_template = str( ROOT / "templates" / "configuration" / "softmax_config.jinja") self.forward_template = str( diff --git a/aidge_export_cpp/templates/configuration/softmax_config.jinja b/aidge_export_cpp/templates/configuration/softmax_config.jinja index d8ec8af..e9661bc 100644 --- a/aidge_export_cpp/templates/configuration/softmax_config.jinja +++ b/aidge_export_cpp/templates/configuration/softmax_config.jinja @@ -7,6 +7,8 @@ {#- Calculate sizes #} {%- set weights_size = out_chan[0] * in_chan[0] * in_height[0] * in_width[0] %} -#define {{ name|upper }}_AXIS {{ axis }} +#define {{ name|upper }}_AXIS_SIZE {{ axis_size }} +#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }} +#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }} #endif /* {{ name|upper }}_LAYER_H */ diff --git a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja index 607ad53..7c8e067 100644 --- a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja +++ b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja @@ -1,12 +1,8 @@ {% filter indent(width=4, first=False) %} {% include "./_mem_offset.jinja" %} -softmax_forward<{{ in_name[0]|upper }}_NB_CHANNELS, - {{ in_name[0]|upper }}_IN_HEIGHT, - {{ in_name[0]|upper }}_IN_WIDTH, - {{ out_name[0]|upper }}_NB_OUTPUTS, - {{ out_name[0]|upper }}_OUT_HEIGHT, - {{ out_name[0]|upper }}_OUT_WIDTH, - {{ name|upper }}_AXIS> - ({{in_name[0]}}, {{out_name[0]}}); +softmax_forward<{{ name|upper }}_AXIS_SIZE, + {{ name|upper }}_AXIS_SIZE_POST, + {{ name|upper }}_AXIS_SIZE_PRE> + ({{in_name[0]}}, {{out_name[0]}}); {% include "./_save_outputs.jinja" %} {% endfilter %} diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py index d8e7814..3d55f11 100644 --- a/aidge_export_cpp/unit_tests/test_export.py +++ b/aidge_export_cpp/unit_tests/test_export.py @@ -169,6 +169,30 @@ class test_operator_export(unittest.TestCase): self.unit_test_export(model, "Softmax", [[1, 10]]) + def test_export_softmax_batch(self): + print("SoftmaxBatch") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=1, name="sf0") + ]) + + self.unit_test_export(model, "SoftmaxBatch", [[3, 10]]) + + def 
test_export_softmax_axis_2(self): + print("SoftmaxAxis2") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=2, name="sf0") + ]) + + self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]]) + + def test_export_softmax_axis_0(self): + print("SoftmaxAxis0") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=0, name="sf0") + ]) + + self.unit_test_export(model, "SoftmaxAxis0", [[10]]) + @unittest.skip("Currently this test is failing") def test_export_FC_image_in(self): """Test exporting a FC operator with a HWC input. -- GitLab
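Note on the axis decomposition introduced by this patch: the kernel views the input as a contiguous [AXIS_SIZE_PRE, AXIS_SIZE, AXIS_SIZE_POST] tensor, where AXIS_SIZE_PRE is the product of the dimensions before the softmax axis, AXIS_SIZE is the size of that axis, and AXIS_SIZE_POST is the product of the dimensions after it; the stride between consecutive elements along the axis is therefore AXIS_SIZE_POST. The sketch below is not part of the patch: it is a hypothetical standalone program (the name softmax_ref and the example shape {2, 3, 4} are made up for illustration) that reimplements the same loop structure with runtime sizes and checks that every slice along the softmax axis sums to 1.

    // Hypothetical standalone check of the AXIS_SIZE_PRE / AXIS_SIZE / AXIS_SIZE_POST
    // index decomposition used by the patched kernel. Not part of the export flow.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Simplified re-implementation of the patched kernel, with runtime sizes
    // instead of template parameters.
    void softmax_ref(const float* in, float* out,
                     int axisSizePre, int axisSize, int axisSizePost)
    {
        for (int i = 0; i < axisSizePre; ++i) {
            for (int j = 0; j < axisSizePost; ++j) {
                // Base index of this (pre, post) slice; elements along the
                // softmax axis are spaced by axisSizePost.
                const int baseIdx = i * axisSize * axisSizePost + j;

                float maxVal = in[baseIdx];
                for (int k = 1; k < axisSize; ++k)
                    maxVal = std::max(maxVal, in[baseIdx + k * axisSizePost]);

                float sumExp = 0.f;
                for (int k = 0; k < axisSize; ++k) {
                    const int idx = baseIdx + k * axisSizePost;
                    out[idx] = std::exp(in[idx] - maxVal);
                    sumExp += out[idx];
                }

                for (int k = 0; k < axisSize; ++k)
                    out[baseIdx + k * axisSizePost] /= sumExp;
            }
        }
    }

    int main()
    {
        // Example: dims = {2, 3, 4}, softmax axis = 1  ->  pre = 2, size = 3, post = 4.
        const int pre = 2, size = 3, post = 4;
        std::vector<float> in(static_cast<std::size_t>(pre * size * post));
        std::vector<float> out(in.size());
        for (std::size_t n = 0; n < in.size(); ++n)
            in[n] = 0.1f * static_cast<float>(n);

        softmax_ref(in.data(), out.data(), pre, size, post);

        // Every slice along the softmax axis should sum to (approximately) 1.
        for (int i = 0; i < pre; ++i) {
            for (int j = 0; j < post; ++j) {
                float s = 0.f;
                for (int k = 0; k < size; ++k)
                    s += out[i * size * post + j + k * post];
                std::printf("slice (%d, %d) sums to %f\n", i, j, s);
            }
        }
        return 0;
    }

This mirrors the index arithmetic computed in operators.py (axis_size_pre, axis_size, axis_size_post) and consumed by the kernel, under the assumption of a row-major contiguous input tensor.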