Skip to content
Snippets Groups Projects

Fixed Softmax impl

Merged Olivier BICHLER requested to merge softmax into dev
1 file
+ 8
2
Compare changes
  • Side-by-side
  • Inline
@@ -39,17 +39,23 @@ void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSi
for (std::size_t i = 0; i < preAxisElems; ++i) {
for (std::size_t j = 0; j < postAxisElems; ++j) {
I maxVal = input[i * inputDims[axisIdx] * postAxisElems + j];
for (std::size_t k = 1; k < inputDims[axisIdx]; ++k) {
std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
maxVal = std::max(maxVal, input[inIdx]);
}
// Calculate sum of exponentials within the axis
I sumExp = 0;
for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
sumExp += std::exp(input[inIdx]);
sumExp += std::exp(input[inIdx] - maxVal);
}
// Calculate softmax for the current slice along the axis
for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
output[inIdx] = std::exp(input[inIdx]) / sumExp;
output[inIdx] = std::exp(input[inIdx] - maxVal) / sumExp;
}
}
}
Loading