From 71e04c1b343a7630616ebd2383c63416c855007b Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 9 Aug 2024 17:19:57 +0200
Subject: [PATCH] Fixed Softmax impl

---
 .../cpu/operator/SoftmaxImpl_forward_kernels.hpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
index cc384c38..6ff8b3dd 100644
--- a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
@@ -39,17 +39,23 @@ void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSi
 
     for (std::size_t i = 0; i < preAxisElems; ++i) {
         for (std::size_t j = 0; j < postAxisElems; ++j) {
+            I maxVal = input[i * inputDims[axisIdx] * postAxisElems + j];
+            for (std::size_t k = 1; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                maxVal = std::max(maxVal, input[inIdx]);
+            }
+
             // Calculate sum of exponentials within the axis
             I sumExp = 0;
             for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
                 std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
-                sumExp += std::exp(input[inIdx]);
+                sumExp += std::exp(input[inIdx] - maxVal);
             }
 
             // Calculate softmax for the current slice along the axis
             for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
                 std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
-                output[inIdx] = std::exp(input[inIdx]) / sumExp;
+                output[inIdx] = std::exp(input[inIdx] - maxVal) / sumExp;
             }
         }
     }
-- 
GitLab