From f43121fa481cc48415430150dd6d9a7feff5af6c Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Mon, 23 Sep 2024 11:14:09 +0200 Subject: [PATCH] Fix Pow Backward kernel --- .../aidge/backend/cpu/operator/PowImpl.hpp | 2 +- .../cpu/operator/PowImpl_backward_kernels.hpp | 28 +++++++++++-------- src/operator/PowImpl.cpp | 11 ++++---- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp index 8d0c0831..747a9135 100644 --- a/include/aidge/backend/cpu/operator/PowImpl.hpp +++ b/include/aidge/backend/cpu/operator/PowImpl.hpp @@ -28,7 +28,7 @@ class PowImplForward_cpu : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { }; class PowImplBackward_cpu - : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, std::size_t , const void*, const void*, const void*, const void*, void*, void*)> { + : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, const void*, void*, void*)> { }; class PowImpl_cpu : public OperatorImpl { diff --git a/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp index f3483eac..ecc6ef48 100644 --- a/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp @@ -12,37 +12,41 @@ #ifndef AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_ #define AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_ -#include "aidge/utils/Registrar.hpp" #include <cmath> +#include <numeric> +#include <vector> #include "aidge/backend/cpu/data/Broadcasting.hpp" #include "aidge/backend/cpu/operator/PowImpl.hpp" -#include <iostream> -#include <vector> +#include "aidge/utils/Registrar.hpp" + namespace Aidge { template <class I1, class I2, class O> void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims, const std::vector<std::size_t>& input1Dims, const std::vector<std::size_t>& outputDims, - std::size_t totalElements, - const void* input0_, - const void* input1_, - const void* output_, - const void* gradOutput_, + const void* input0_, + const void* input1_, + const void* gradOutput_, void* gradientInput0_, void* gradientInput1_) { const I1* input0 = static_cast<const I1*>(input0_); I1* grad0 = static_cast<I1*>(gradientInput0_); const I2* input1 = static_cast<const I2*>(input1_); I2* grad1 = static_cast<I2*>(gradientInput1_); - const O* output = static_cast<const O*>(output_); const O* gradOut = static_cast<const O*>(gradOutput_); - + + + auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); + auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); + std::fill(grad0, grad0 + input0Elements, I1(0)); + auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>()); + std::fill(grad1, grad1 + input1Elements, I1(0)); + for (size_t i = 0; i < totalElements; ++i) { std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i); - std::size_t idx0 = getFlattenedIndex(input0Dims, indexes); std::size_t idx1 = getFlattenedIndex(input1Dims, indexes); @@ -50,7 +54,7 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims, grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1); // grad1 = grad_output * (output * ln(input0)) - grad1[idx1] += gradOut[i]*output[i] * std::log(input0[idx0]); + grad1[idx1] += gradOut[i] * std::pow(input0[idx0], input1[idx1]) * std::log(input0[idx0]); } } diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp index 97055e4b..dfdb8fcf 100644 --- a/src/operator/PowImpl.cpp +++ b/src/operator/PowImpl.cpp @@ -21,6 +21,7 @@ #include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/operator/PowImpl.hpp" +#include "aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp" #include "aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp" Aidge::Elts_t Aidge::PowImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { @@ -57,17 +58,15 @@ void Aidge::PowImpl_cpu::backward() { op_.getInput(0)->grad()->dataType(), op_.getInput(1)->grad()->dataType()}); - const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getInput(0)->grad()->dims(), - op_.getOutput(0)->grad()->dims()); - const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getInput(1)->grad()->dims(), - op_.getOutput(0)->grad()->dims()); + const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getOutput(0)->grad()->dims(), + op_.getInput(0)->grad()->dims()); + const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getOutput(0)->grad()->dims(), + op_.getInput(1)->grad()->dims()); // Call kernel kernelFunc(input0gradDims, input1gradDims, op_.getOutput(0)->grad()->dims(), - op_.getOutput(0)->size(), - getCPUPtr(mOp.getRawOutput(0)), getCPUPtr(mOp.getRawInput(0)), getCPUPtr(mOp.getRawInput(1)), getCPUPtr(op_.getOutput(0)->grad()), -- GitLab