From 9de5fc92a15f961cd113af90587a243d3efefa02 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Wed, 25 Sep 2024 15:09:09 +0200 Subject: [PATCH] minor kernel cleanings --- .../backend/cpu/operator/PowImpl_kernels.hpp | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/include/aidge/backend/cpu/operator/PowImpl_kernels.hpp b/include/aidge/backend/cpu/operator/PowImpl_kernels.hpp index a89dc9ff..ab9b2ccc 100644 --- a/include/aidge/backend/cpu/operator/PowImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/PowImpl_kernels.hpp @@ -31,14 +31,10 @@ void PowImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims, const I2* input_2 = static_cast<const I2*>(input2_); O* output = static_cast<O*>(output_); - size_t totalElements = 1; - for (size_t dimSize : outputDims) { - totalElements *= dimSize; - } - + std::size_t totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex) { - std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex); + std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, oIndex); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); std::size_t idx2 = getFlattenedIndex(input2Dims, indexes); @@ -63,24 +59,24 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims, const O* gradOut = static_cast<const O*>(gradOutput_); // Fill input grads with zeros - auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); + std::size_t input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); std::fill(grad0, grad0 + input0Elements, I1(0)); - auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); - std::fill(grad1, grad1 + input1Elements, I1(0)); + std::size_t input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); + std::fill(grad1, grad1 + input1Elements, I2(0)); - auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); - for (size_t i = 0; i < totalElements; ++i) + std::size_t totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); + for (size_t oIndex = 0; oIndex < totalElements; ++oIndex) { // Compute indexes in inputs 0 and 1 to support broadcasting - std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i); + std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, oIndex); std::size_t idx0 = getFlattenedIndex(input0Dims, indexes); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); // grad0 = grad_output * (input1 * pow(input0, (input1 -1))) - grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1); + grad0[idx0] += gradOut[oIndex]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1); // grad1 = grad_output * (output * ln(input0)) - grad1[idx1] += gradOut[i] * std::pow(input0[idx0], input1[idx1]) * std::log(input0[idx0]); + grad1[idx1] += gradOut[oIndex] * std::pow(input0[idx0], input1[idx1]) * std::log(input0[idx0]); } } -- GitLab