From 71e04c1b343a7630616ebd2383c63416c855007b Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 9 Aug 2024 17:19:57 +0200
Subject: [PATCH] Fixed Softmax impl

---
 .../cpu/operator/SoftmaxImpl_forward_kernels.hpp       | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
index cc384c38..6ff8b3dd 100644
--- a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
@@ -39,17 +39,23 @@ void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSi
 
     for (std::size_t i = 0; i < preAxisElems; ++i) {
         for (std::size_t j = 0; j < postAxisElems; ++j) {
+            I maxVal = input[i * inputDims[axisIdx] * postAxisElems + j];
+            for (std::size_t k = 1; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                maxVal = std::max(maxVal, input[inIdx]);
+            }
+
             // Calculate sum of exponentials within the axis
             I sumExp = 0;
             for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
                 std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
-                sumExp += std::exp(input[inIdx]);
+                sumExp += std::exp(input[inIdx] - maxVal);
             }
 
             // Calculate softmax for the current slice along the axis
             for (std::size_t  k = 0; k < inputDims[axisIdx]; ++k) {
                 std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
-                output[inIdx] = std::exp(input[inIdx]) / sumExp;
+                output[inIdx] = std::exp(input[inIdx] - maxVal) / sumExp;
             }
         }
     }
-- 
GitLab