Skip to content
Snippets Groups Projects
Commit bd212be1 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

minor kernel cleanups

parent 9d80d4f2
No related branches found
No related tags found
Loading
This commit is part of merge request !90. Comments created here will be created in the context of that merge request.
...@@ -37,20 +37,21 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims, ...@@ -37,20 +37,21 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
I2* grad1 = static_cast<I2*>(gradientInput1_); I2* grad1 = static_cast<I2*>(gradientInput1_);
const O* gradOut = static_cast<const O*>(gradOutput_); const O* gradOut = static_cast<const O*>(gradOutput_);
// Fill input grads with zeros
auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
std::fill(grad0, grad0 + input0Elements, I1(0)); std::fill(grad0, grad0 + input0Elements, I1(0));
auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
std::fill(grad1, grad1 + input1Elements, I1(0)); std::fill(grad1, grad1 + input1Elements, I1(0));
auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
for (size_t i = 0; i < totalElements; ++i) for (size_t i = 0; i < totalElements; ++i)
{ {
// Compute indexes in inputs 0 and 1 to support broadcasting
std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i); std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i);
std::size_t idx0 = getFlattenedIndex(input0Dims, indexes); std::size_t idx0 = getFlattenedIndex(input0Dims, indexes);
std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
// grad0 = input1 * pow (input0, (input1 -1)) // grad0 = grad_output * (input1 * pow(input0, (input1 -1)))
grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1); grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1);
// grad1 = grad_output * (output * ln(input0)) // grad1 = grad_output * (output * ln(input0))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment