Skip to content
Snippets Groups Projects
Commit bd212be1 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

minor kernel cleanups

parent 9d80d4f2
No related branches found
No related tags found
No related merge requests found
Pipeline #55283 passed
......@@ -37,20 +37,21 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
I2* grad1 = static_cast<I2*>(gradientInput1_);
const O* gradOut = static_cast<const O*>(gradOutput_);
auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
// Fill input grads with zeros
auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
std::fill(grad0, grad0 + input0Elements, I1(0));
auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
std::fill(grad1, grad1 + input1Elements, I1(0));
auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
for (size_t i = 0; i < totalElements; ++i)
{
// Compute indexes in inputs 0 and 1 to support broadcasting
std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i);
std::size_t idx0 = getFlattenedIndex(input0Dims, indexes);
std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
// grad0 = input1 * pow (input0, (input1 -1))
// grad0 = grad_output * (input1 * pow(input0, (input1 -1)))
grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1);
// grad1 = grad_output * (output * ln(input0))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment