Commit 937b8fe6 authored by Gallas Gaye

Pass on softmax op

parent 2131467a
Pipeline #69935 failed
@@ -6,8 +6,8 @@
 #include "kernels/macs.hpp"
 #include <type_traits>
 #include <cmath>
+#include <algorithm>
 
 template<int NB_CHANNELS,
          int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
@@ -20,36 +20,56 @@ void softmax_forward (
     const Input_T* __restrict inputs,
     Output_T* __restrict outputs)
 {
-    Input_T maxValue = 0.0f;
-
-    for (int och = 0; och < NB_OUTPUTS; och++) {
-        maxValue = std::max(maxValue, inputs[och]);
-    }
-
-    Input_T sumExp = 0.0f;
-
-    if constexpr (std::is_same_v<Input_T, Output_T>) {
-        // This should be more performant while keeping the same memory footprint,
-        // but it can only be used if Input_T and Output_T are the same type!
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] = std::exp(inputs[och] - maxValue);
-            sumExp += outputs[och];
-        }
-
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] /= sumExp;
-        }
-    }
-    else {
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            sumExp += std::exp(inputs[och] - maxValue);
-        }
-
-        for (int och = 0; och < NB_OUTPUTS; och++) {
-            outputs[och] = std::exp(inputs[och] - maxValue) / sumExp;
-        }
-    }
+    // TODO: this dims calculation cannot work in general, since this operator can take
+    // an arbitrary number of dimensions. It currently only works for axis 0 and 1; to
+    // handle the general case, the dims and the number of dims must be passed as
+    // arguments to the function.
+    constexpr int nDims = 3;
+    constexpr int dims[3] = {NB_CHANNELS, NB_OUTPUTS, CHANNELS_WIDTH};
+    // Normalize a negative axis (e.g. AXIS == -1 with nDims == 3 maps to 2).
+    constexpr int axisIdx = (AXIS < 0) ? AXIS + nDims : AXIS;
+
+    constexpr int preAxisElems =
+        (axisIdx == 0) ? 1
+        : (axisIdx == 1) ? 1
+        : (axisIdx == 2) ? NB_CHANNELS * CHANNELS_HEIGHT
+        : 0; // Should not occur if axisIdx is valid
+    constexpr int axisSize = dims[axisIdx];
+    constexpr int postAxisElems =
+        (axisIdx == 2) ? 1
+        : (axisIdx == 1) ? 1
+        : (axisIdx == 0) ? CHANNELS_HEIGHT * CHANNELS_WIDTH
+        : 0; // Should not occur if axisIdx is valid
+
+    // Iterate over the "pre-axis" and "post-axis" slices. For each slice along the
+    // axis, compute the maximum value, then the sum of exponentials, and finally
+    // write the normalized softmax outputs.
+    for (int i = 0; i < preAxisElems; ++i) {
+        for (int j = 0; j < postAxisElems; ++j) {
+            // Compute the base index for this slice.
+            const int baseIdx = i * axisSize * postAxisElems + j;
+
+            // Find the maximum value along the axis (for numerical stability).
+            Input_T maxVal = inputs[baseIdx];
+            for (int k = 1; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                maxVal = std::max(maxVal, inputs[idx]);
+            }
+
+            // Compute the sum of the exponentials along the axis.
+            Input_T sumExp = 0;
+            for (int k = 0; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                outputs[idx] = std::exp(inputs[idx] - maxVal);
+                sumExp += outputs[idx];
+            }
+
+            // Normalize and write the softmax values to the output.
+            for (int k = 0; k < axisSize; ++k) {
+                const int idx = i * axisSize * postAxisElems + k * postAxisElems + j;
+                outputs[idx] /= sumExp;
+            }
+        }
+    }
 }
 
 #endif // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
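
For reference, the TODO in the kernel above asks for a variant that receives the dims and their count at runtime instead of hard-coding three dimensions. Below is a minimal standalone sketch of that generalization, using the same pre-axis / axis / post-axis decomposition and the same flat-index formula as the kernel; the function name softmax_axis and its signature are illustrative only, not part of this commit.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Softmax along one axis of a row-major tensor of arbitrary rank.
    // Element (i, k, j) of the (preAxisElems, axisSize, postAxisElems) view
    // lives at flat index i * axisSize * postAxisElems + k * postAxisElems + j,
    // which is exactly the indexing used by the kernel above.
    void softmax_axis(const float* in, float* out,
                      const int* dims, int nDims, int axis) {
        const int axisIdx = (axis < 0) ? axis + nDims : axis;
        int preAxisElems = 1, postAxisElems = 1;
        for (int d = 0; d < axisIdx; ++d) preAxisElems *= dims[d];
        for (int d = axisIdx + 1; d < nDims; ++d) postAxisElems *= dims[d];
        const int axisSize = dims[axisIdx];

        for (int i = 0; i < preAxisElems; ++i) {
            for (int j = 0; j < postAxisElems; ++j) {
                const int base = i * axisSize * postAxisElems + j;
                // Max along the axis, for numerical stability.
                float maxVal = in[base];
                for (int k = 1; k < axisSize; ++k)
                    maxVal = std::max(maxVal, in[base + k * postAxisElems]);
                // Exponentials and their sum.
                float sumExp = 0.0f;
                for (int k = 0; k < axisSize; ++k) {
                    const int idx = base + k * postAxisElems;
                    out[idx] = std::exp(in[idx] - maxVal);
                    sumExp += out[idx];
                }
                // Normalize.
                for (int k = 0; k < axisSize; ++k)
                    out[base + k * postAxisElems] /= sumExp;
            }
        }
    }

    int main() {
        // 2x3 tensor, softmax along axis 1: each row of the output sums to 1.
        const int dims[2] = {2, 3};
        const float in[6] = {1.f, 2.f, 3.f, -1.f, 0.f, 1.f};
        float out[6];
        softmax_axis(in, out, dims, 2, 1);
        for (float v : out) std::printf("%.4f ", v);
        std::printf("\n");
        return 0;
    }

With dims passed at runtime, the axis-2 case exercised by the expectedFailure test below would also be covered, at the cost of losing the constexpr slice sizes.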
@@ -334,6 +334,11 @@ class SoftmaxCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
         self.attributes["axis"] = node.get_operator().attr.axis
+        assert self.attributes["axis"] == 0 or self.attributes["axis"] == 1, (
+            f"export softmax: attribute axis == {node.get_operator().attr.axis} not implemented"
+        )
+
         self.config_template = str(
             ROOT / "templates" / "configuration" / "softmax_config.jinja")
         self.forward_template = str(
...
@@ -166,6 +166,23 @@ class test_operator_export(unittest.TestCase):
         self.unit_test_export(model, "Softmax", [[1, 10]])
 
+    @unittest.expectedFailure
+    def test_export_softmax_axis_2(self):
+        print("SoftmaxAxis2")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=2, name="sf0")
+        ])
+        self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]])
+
+    def test_export_softmax_axis_0(self):
+        print("SoftmaxAxis0")
+        model = aidge_core.sequential([
+            aidge_core.Softmax(axis=0, name="sf0")
+        ])
+        self.unit_test_export(model, "SoftmaxAxis0", [[10]])
+
     @unittest.skip("Currently this test is failing")
     def test_export_FC_image_in(self):
         """Test exporting a FC operator with a HWC input.
...