diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp
index 73d00da058ac53c7c625ae66d65a9aead19559a4..5e2444a0b10d874c64bbb384555278b0208617f9 100644
--- a/aidge_export_cpp/kernels/softmax.hpp
+++ b/aidge_export_cpp/kernels/softmax.hpp
@@ -6,8 +6,8 @@
 
 #include "kernels/macs.hpp"
 #include <type_traits>
-
 #include <cmath>
+#include <algorithm>
 
 template<int NB_CHANNELS,
          int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
@@ -20,36 +20,56 @@ void softmax_forward (
     const Input_T* __restrict inputs,
     Output_T* __restrict outputs)
 {
-    Input_T maxValue = 0.0f;
+    // TODO: this dims calculation cannot work in general, as this operator can take an arbitrary number of dimensions.
+    // It currently only works for axis 0 and 1; to make it work correctly, we need to pass the dims and the
+    // number of dims as arguments to the function.
+    constexpr int nDims = 3;
+    constexpr int dims[3] = {NB_CHANNELS, NB_OUTPUTS, CHANNELS_WIDTH};
+    constexpr int axisIdx = (AXIS < 0) ? AXIS + nDims : AXIS;
+    constexpr int preAxisElems =
+        (axisIdx == 0) ? 1
+        : (axisIdx == 1) ? 1
+        : (axisIdx == 2) ? NB_CHANNELS * CHANNELS_HEIGHT
+        : 0; // Should not occur if axisIdx is valid
+    constexpr int axisSize = dims[axisIdx];
+    constexpr int postAxisElems =
+        (axisIdx == 2) ? 1
+        : (axisIdx == 1) ? 1
+        : (axisIdx == 0) ? CHANNELS_HEIGHT * CHANNELS_WIDTH
+        : 0; // Should not occur if axisIdx is valid
 
-    for (int och = 0; och < NB_OUTPUTS; och++) {
-        maxValue = std::max(maxValue, inputs[och]);
-    }
-
-    Input_T sumExp = 0.0f;
-
-    if constexpr (std::is_same_v<Input_T, Output_T>) {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            // This should be both more performant while keeping the same memory footprint but we can only use it if INPUT_T and OUTPUT_T types are the same !
-            outputs[och] = std::exp(inputs[och] - maxValue);
-            sumExp += outputs[och];
-        }
+    // Iterate over the "pre-axis" and "post-axis" slices.
+    // For each slice along the axis, compute the maximum value,
+    // the sum of exponentials, and then write the normalized softmax outputs.
+    for (int i = 0; i < preAxisElems; ++i) {
+        for (int j = 0; j < postAxisElems; ++j) {
+            // Compute the base index for this slice.
+            const int baseIdx = i * axisSize * postAxisElems + j;
 
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] /= sumExp;
-        }
-    }
-    else
-    {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            sumExp += std::exp(inputs[och] - maxValue);
-        }
+            // Find the maximum value along the axis.
+            Input_T maxVal = inputs[baseIdx];
+            for (int k = 1; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                maxVal = std::max(maxVal, inputs[idx]);
+            }
 
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] = std::exp(inputs[och] - maxValue) / sumExp;
+            // Compute the sum of the exponentials along the axis.
+            Input_T sumExp = 0;
+            for (int k = 0; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                outputs[idx] = std::exp(inputs[idx] - maxVal);
+                sumExp += outputs[idx];
+            }
+
+            // Write the softmax values to the output.
+            for (int k = 0; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                outputs[idx] /= sumExp;
+            }
         }
     }
 }
-
 #endif // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 90d039fe73c93a380fa85baa939a61b207e48fba..2fdb82bc91f8ca502fee9107edc7b6641bf9a33d 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -334,6 +334,11 @@ class SoftmaxCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
         self.attributes["axis"] = node.get_operator().attr.axis
+
+        assert self.attributes["axis"] == 0 or self.attributes["axis"] == 1, (
+            f"export softmax: attribute axis == {node.get_operator().attr.axis} not implemented"
+        )
+
         self.config_template = str(
             ROOT / "templates" / "configuration" / "softmax_config.jinja")
         self.forward_template = str(
diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py
index 4f9ea7a175ef238c8103081c92e78dcd8e663066..51eaf7bd59f91b21b8d94842843157d6acf1c5b9 100644
--- a/aidge_export_cpp/unit_tests/test_export.py
+++ b/aidge_export_cpp/unit_tests/test_export.py
@@ -166,6 +166,23 @@ class test_operator_export(unittest.TestCase):
 
         self.unit_test_export(model, "Softmax", [[1, 10]])
 
+    @unittest.expectedFailure
+    def test_export_softmax_axis_2(self):
+        print("SoftmaxAxis2")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=2, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]])
+
+    def test_export_softmax_axis_0(self):
+        print("SoftmaxAxis0")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=0, name="sf0")
+        ])
+
+        self.unit_test_export(model, "SoftmaxAxis0", [[10]])
+
     @unittest.skip("Currently this test is failing")
     def test_export_FC_image_in(self):
         """Test exporting a FC operator with a HWC input.
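
Note on the TODO in softmax.hpp: the proper fix it describes is to pass the tensor shape to the kernel at runtime instead of hard-coding a 3-D layout. Below is a minimal sketch of that generalization, following the same pre-axis/axis/post-axis decomposition as the patch; the softmax_forward_nd name and the dims/nbDims parameters are hypothetical and not part of this patch.

// Sketch only: softmax over an arbitrary axis of an N-D tensor, with the
// shape passed at runtime as suggested by the TODO. Names are hypothetical.
#include <algorithm>
#include <cmath>

template<typename Input_T, typename Output_T>
void softmax_forward_nd(const Input_T* __restrict inputs,
                        Output_T* __restrict outputs,
                        const int* dims, int nbDims, int axis)
{
    const int axisIdx = (axis < 0) ? axis + nbDims : axis;

    // Flatten the shape to [preAxisElems, axisSize, postAxisElems].
    int preAxisElems = 1;
    for (int d = 0; d < axisIdx; ++d) preAxisElems *= dims[d];
    const int axisSize = dims[axisIdx];
    int postAxisElems = 1;
    for (int d = axisIdx + 1; d < nbDims; ++d) postAxisElems *= dims[d];

    for (int i = 0; i < preAxisElems; ++i) {
        for (int j = 0; j < postAxisElems; ++j) {
            const int base = i * axisSize * postAxisElems + j;

            // Subtract the slice maximum for numerical stability.
            Input_T maxVal = inputs[base];
            for (int k = 1; k < axisSize; ++k)
                maxVal = std::max(maxVal, inputs[base + k * postAxisElems]);

            // Exponentiate into the output buffer and accumulate the sum.
            Input_T sumExp = 0;
            for (int k = 0; k < axisSize; ++k) {
                const int idx = base + k * postAxisElems;
                outputs[idx] = std::exp(inputs[idx] - maxVal);
                sumExp += outputs[idx];
            }

            // Normalize the slice.
            for (int k = 0; k < axisSize; ++k)
                outputs[base + k * postAxisElems] /= sumExp;
        }
    }
}

Once the shape is passed in this way, the axis-0/axis-1 assert added in operators.py and the expectedFailure marker on test_export_softmax_axis_2 could both be dropped.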