From 120d32f2b91d0755a5b45637387a462e2e97a1e2 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Fri, 26 Jan 2024 11:47:18 +0100
Subject: [PATCH] add broadcasting for Pow operator

---
 .../aidge/backend/cpu/operator/PowImpl.hpp    |  4 +-
 .../cpu/operator/PowImpl_forward_kernels.hpp  | 43 +++++++++----------
 src/operator/PowImpl.cpp                      | 11 ++++-
 3 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp
index d3cafa7e..c6e4cd36 100644
--- a/include/aidge/backend/cpu/operator/PowImpl.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl.hpp
@@ -25,10 +25,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class PowImplForward_cpu
-    : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*,void*)> {
+    : public Registrable<PowImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class PowImplBackward_cpu
-    : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*, void*)> {
+    : public Registrable<PowImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
 };
 
 class PowImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp
index c9c5db7e..1146cfa7 100644
--- a/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp
@@ -15,39 +15,36 @@
 #include "aidge/utils/Registrar.hpp"
 #include <cmath>
 
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
 
 namespace Aidge {
 template <class I1, class I2, class O>
-void PowImpl_cpu_forward_kernel(std::size_t input1Length,
-                                     std::size_t input2Length,
-                                     const void* input1_,
-                                     const void* input2_,
-                                     void* output_) {
+void PowImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
+                                const std::vector<std::size_t>& input2Dims,
+                                const std::vector<std::size_t>& outputDims,
+                                const void* input1_,
+                                const void* input2_,
+                                void* output_) {
 
     const I1* input_1 = static_cast<const I1*>(input1_);
     const I2* input_2 = static_cast<const I2*>(input2_);
     O* output = static_cast<O*>(output_);
 
-    if (input2Length == input1Length)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = std::pow(input_1[i], input_2[i]);
-        }
-    }
-    else if (input2Length == 1)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = std::pow(input_1[i], input_2[0]);
-        }
-    }
-    else // input_2 is 1d and of size the number of channels of input_1
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            std::size_t channelIdx = i % input2Length;
-            output[i] = std::pow(input_1[i], input_2[channelIdx]);
-        }
+    size_t totalElements = 1;
+    for (size_t dimSize : outputDims) {
+        totalElements *= dimSize;
     }
+
+	for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex) 
+	{
+		std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
+
+		std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
+		std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
+		
+        output[oIndex] = std::pow(input_1[idx1], input_2[idx2]);
+	}
 }
 
 namespace {
diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp
index 49664640..22b4e27a 100644
--- a/src/operator/PowImpl.cpp
+++ b/src/operator/PowImpl.cpp
@@ -17,6 +17,7 @@
 
 #include "aidge/operator/Pow.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
@@ -34,9 +35,15 @@ void Aidge::PowImpl_cpu::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
+    const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
+    const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
+
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(),
+    kernelFunc(inputDims0,
+        inputDims1,
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
-- 
GitLab