Skip to content
Snippets Groups Projects
Commit 9de5fc92 authored by Houssem ROUIS
Browse files

minor kernel cleanings

parent 30083903
No related branches found
No related tags found
No related merge requests found
...@@ -31,14 +31,10 @@ void PowImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims, ...@@ -31,14 +31,10 @@ void PowImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
const I2* input_2 = static_cast<const I2*>(input2_); const I2* input_2 = static_cast<const I2*>(input2_);
O* output = static_cast<O*>(output_); O* output = static_cast<O*>(output_);
size_t totalElements = 1; std::size_t totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
for (size_t dimSize : outputDims) {
totalElements *= dimSize;
}
for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex) for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex)
{ {
std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex); std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, oIndex);
std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
std::size_t idx2 = getFlattenedIndex(input2Dims, indexes); std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
...@@ -63,24 +59,24 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims, ...@@ -63,24 +59,24 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
const O* gradOut = static_cast<const O*>(gradOutput_); const O* gradOut = static_cast<const O*>(gradOutput_);
// Fill input grads with zeros // Fill input grads with zeros
auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); std::size_t input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
std::fill(grad0, grad0 + input0Elements, I1(0)); std::fill(grad0, grad0 + input0Elements, I1(0));
auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); std::size_t input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
std::fill(grad1, grad1 + input1Elements, I1(0)); std::fill(grad1, grad1 + input1Elements, I2(0));
auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); std::size_t totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
for (size_t i = 0; i < totalElements; ++i) for (size_t oIndex = 0; oIndex < totalElements; ++oIndex)
{ {
// Compute indexes in inputs 0 and 1 to support broadcasting // Compute indexes in inputs 0 and 1 to support broadcasting
std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i); std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, oIndex);
std::size_t idx0 = getFlattenedIndex(input0Dims, indexes); std::size_t idx0 = getFlattenedIndex(input0Dims, indexes);
std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
// grad0 = grad_output * (input1 * pow(input0, (input1 -1))) // grad0 = grad_output * (input1 * pow(input0, (input1 -1)))
grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1); grad0[idx0] += gradOut[oIndex]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1);
// grad1 = grad_output * (output * ln(input0)) // grad1 = grad_output * (output * ln(input0))
grad1[idx1] += gradOut[i] * std::pow(input0[idx0], input1[idx1]) * std::log(input0[idx0]); grad1[idx1] += gradOut[oIndex] * std::pow(input0[idx0], input1[idx1]) * std::log(input0[idx0]);
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment