diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp index cc384c38e34d01887fc328d11de383aeef39fb8e..6ff8b3ddf39412aa6febdc188b7c27e8bfdcc178 100644 --- a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp @@ -39,17 +39,23 @@ void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSi for (std::size_t i = 0; i < preAxisElems; ++i) { for (std::size_t j = 0; j < postAxisElems; ++j) { + I maxVal = input[i * inputDims[axisIdx] * postAxisElems + j]; + for (std::size_t k = 1; k < inputDims[axisIdx]; ++k) { + std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j; + maxVal = std::max(maxVal, input[inIdx]); + } + // Calculate sum of exponentials within the axis I sumExp = 0; for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) { std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j; - sumExp += std::exp(input[inIdx]); + sumExp += std::exp(input[inIdx] - maxVal); } // Calculate softmax for the current slice along the axis for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) { std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j; - output[inIdx] = std::exp(input[inIdx]) / sumExp; + output[inIdx] = std::exp(input[inIdx] - maxVal) / sumExp; } } }