Skip to content
Snippets Groups Projects
Commit 83b43f32 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'softmax' into 'dev'

Fixed Softmax impl

See merge request !76
parents 69c994fe 71e04c1b
No related branches found
No related tags found
2 merge requests!93Release v0.3.0,!76Fixed Softmax impl
Pipeline #53254 passed
...@@ -39,17 +39,23 @@ void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSi ...@@ -39,17 +39,23 @@ void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSi
for (std::size_t i = 0; i < preAxisElems; ++i) { for (std::size_t i = 0; i < preAxisElems; ++i) {
for (std::size_t j = 0; j < postAxisElems; ++j) { for (std::size_t j = 0; j < postAxisElems; ++j) {
I maxVal = input[i * inputDims[axisIdx] * postAxisElems + j];
for (std::size_t k = 1; k < inputDims[axisIdx]; ++k) {
std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
maxVal = std::max(maxVal, input[inIdx]);
}
// Calculate sum of exponentials within the axis // Calculate sum of exponentials within the axis
I sumExp = 0; I sumExp = 0;
for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) { for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j; std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
sumExp += std::exp(input[inIdx]); sumExp += std::exp(input[inIdx] - maxVal);
} }
// Calculate softmax for the current slice along the axis // Calculate softmax for the current slice along the axis
for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) { for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j; std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
output[inIdx] = std::exp(input[inIdx]) / sumExp; output[inIdx] = std::exp(input[inIdx] - maxVal) / sumExp;
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment