From f43121fa481cc48415430150dd6d9a7feff5af6c Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Mon, 23 Sep 2024 11:14:09 +0200
Subject: [PATCH] Fix Pow Backward kernel

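The backward kernel previously took the forward output tensor and a
precomputed element count as extra arguments. Both are dropped from the
kernel signature: the element count is now derived from outputDims, and
the output value is recomputed inside the kernel as pow(input0, input1).

The gradient buffers are zero-initialized before the accumulation loop,
so broadcast inputs receive the sum of the contributions of every output
element that maps onto them. PowImpl_cpu::backward() now calls
getBroadcastedDims() with the output gradient dims as the first argument
and each input gradient dims as the second.

The kernel implements the usual power-rule derivatives:

    d/dx (x^y) = y * x^(y-1)
    d/dy (x^y) = x^y * ln(x)

For instance, with x = 2, y = 3 and a unit output gradient, this gives
grad0 = 3 * 2^2 = 12 and grad1 = 2^3 * ln(2) ≈ 5.55.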
---
 .../aidge/backend/cpu/operator/PowImpl.hpp    |  2 +-
 .../cpu/operator/PowImpl_backward_kernels.hpp | 30 ++++++++++++--------
 src/operator/PowImpl.cpp                      | 11 ++++----
 3 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp
index 8d0c0831..747a9135 100644
--- a/include/aidge/backend/cpu/operator/PowImpl.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl.hpp
@@ -28,7 +28,7 @@ class PowImplForward_cpu
     : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class PowImplBackward_cpu
-    : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, std::size_t , const void*, const void*, const void*, const void*, void*, void*)> {
+    : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, const void*, void*, void*)> {
 };
 
 class PowImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp
index f3483eac..ecc6ef48 100644
--- a/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp
@@ -12,37 +12,43 @@
 #ifndef AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_
 #define AIDGE_CPU_OPERATOR_POWIMPL_BACKWARD_KERNEL_H_
 
-#include "aidge/utils/Registrar.hpp"
 #include <cmath>
+#include <numeric>
+#include <vector>
 
 #include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
-#include <iostream>
-#include <vector>
+#include "aidge/utils/Registrar.hpp"
+
 
 namespace Aidge {
 template <class I1, class I2, class O>
 void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
                                 const std::vector<std::size_t>& input1Dims,
                                 const std::vector<std::size_t>& outputDims,
-								std::size_t totalElements,
-                                const void* input0_, 
-                                const void* input1_, 
-                                const void* output_, 
-                                const void* gradOutput_, 
+                                const void* input0_,
+                                const void* input1_,
+                                const void* gradOutput_,
                                 void* gradientInput0_,
                                 void* gradientInput1_) {
 	const I1* input0 = static_cast<const I1*>(input0_);
 	I1* grad0 = static_cast<I1*>(gradientInput0_);
     const I2* input1 = static_cast<const I2*>(input1_);
     I2* grad1 = static_cast<I2*>(gradientInput1_);
-    const O* output = static_cast<const O*>(output_);
     const O* gradOut = static_cast<const O*>(gradOutput_);
-    
+
+    // Zero-initialize the gradients before accumulation (inputs may be broadcast).
+    const auto totalElements = std::accumulate(outputDims.cbegin(), outputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+    const auto input0Elements = std::accumulate(input0Dims.cbegin(), input0Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+    std::fill(grad0, grad0 + input0Elements, I1(0));
+    const auto input1Elements = std::accumulate(input1Dims.cbegin(), input1Dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
+    std::fill(grad1, grad1 + input1Elements, I2(0));
+
     for (size_t i = 0; i < totalElements; ++i)
     {
         std::vector<std::size_t> indexes = getMultiDimIndices(outputDims, i);
-        
         std::size_t idx0 = getFlattenedIndex(input0Dims, indexes);
         std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
 
@@ -50,7 +56,7 @@ void PowImpl_cpu_backward_kernel(const std::vector<std::size_t>& input0Dims,
         grad0[idx0] += gradOut[i]*input1[idx1]* std::pow(input0[idx0], input1[idx1]-1);
 
         // grad1 = grad_output * (output * ln(input0))
-        grad1[idx1] += gradOut[i]*output[i] * std::log(input0[idx0]);
+        grad1[idx1] += gradOut[i] * std::pow(input0[idx0], input1[idx1]) * std::log(input0[idx0]);
     }
 }
 
diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp
index 97055e4b..dfdb8fcf 100644
--- a/src/operator/PowImpl.cpp
+++ b/src/operator/PowImpl.cpp
@@ -21,6 +21,7 @@
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
+#include "aidge/backend/cpu/operator/PowImpl_backward_kernels.hpp"
 #include "aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp"
 
 Aidge::Elts_t Aidge::PowImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
@@ -57,17 +58,15 @@ void Aidge::PowImpl_cpu::backward() {
         op_.getInput(0)->grad()->dataType(),
         op_.getInput(1)->grad()->dataType()});
 
-    const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getInput(0)->grad()->dims(),
-                                                                   op_.getOutput(0)->grad()->dims());
-    const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getInput(1)->grad()->dims(),
-                                                                   op_.getOutput(0)->grad()->dims());
+    const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getOutput(0)->grad()->dims(),
+                                                                       op_.getInput(0)->grad()->dims());
+    const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getOutput(0)->grad()->dims(),
+                                                                       op_.getInput(1)->grad()->dims());
 
     // Call kernel
     kernelFunc(input0gradDims,
                input1gradDims,
                op_.getOutput(0)->grad()->dims(),
-               op_.getOutput(0)->size(),
-               getCPUPtr(mOp.getRawOutput(0)),
                getCPUPtr(mOp.getRawInput(0)),
                getCPUPtr(mOp.getRawInput(1)),
                getCPUPtr(op_.getOutput(0)->grad()),
-- 
GitLab