From 05ca42758e7055f74f7d465a0b057f664e9c085c Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Fri, 26 Jan 2024 11:47:50 +0100
Subject: [PATCH] add broadcasting for Sub operator

---
 .../aidge/backend/cpu/operator/SubImpl.hpp    |  4 +-
 .../cpu/operator/SubImpl_forward_kernels.hpp  | 42 +++++++++----------
 src/operator/SubImpl.cpp                      | 11 ++++-
 3 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/SubImpl.hpp b/include/aidge/backend/cpu/operator/SubImpl.hpp
index 2d4c22f0..b329ec6e 100644
--- a/include/aidge/backend/cpu/operator/SubImpl.hpp
+++ b/include/aidge/backend/cpu/operator/SubImpl.hpp
@@ -25,10 +25,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class SubImplForward_cpu
-    : public Registrable<SubImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*,void*)> {
+    : public Registrable<SubImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
 };
 class SubImplBackward_cpu
-    : public Registrable<SubImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, const std::size_t, const void*, const void*, void*)> {
+    : public Registrable<SubImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
 };
 
 class SubImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp
index 08f2e24f..19b0bd21 100644
--- a/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp
@@ -14,39 +14,35 @@
 
 #include "aidge/utils/Registrar.hpp"
 
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/operator/SubImpl.hpp"
 
+
 namespace Aidge {
 template <class I1, class I2, class O>
-void SubImpl_cpu_forward_kernel(std::size_t input1Length,
-                                     std::size_t input2Length,
-                                     const void* input1_,
-                                     const void* input2_,
-                                     void* output_) {
+void SubImpl_cpu_forward_kernel(const std::vector<std::size_t>& input1Dims,
+                                const std::vector<std::size_t>& input2Dims,
+                                const std::vector<std::size_t>& outputDims,
+                                const void* input1_,
+                                const void* input2_,
+                                void* output_) {
 
     const I1* input_1 = static_cast<const I1*>(input1_);
     const I2* input_2 = static_cast<const I2*>(input2_);
     O* output = static_cast<O*>(output_);
 
-    if (input2Length == input1Length)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = input_1[i] - input_2[i];
-        }
-    }
-    else if (input2Length == 1)
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            output[i] = input_1[i] - input_2[0];
-        }
-    }
-    else // input_2 is 1d and of size the number of channels of input_1
-    {
-        for (std::size_t i = 0; i < input1Length; ++i) {
-            std::size_t channelIdx = i % input2Length;
-            output[i] = input_1[i] - input_2[channelIdx];
-        }
+    size_t totalElements = 1;
+    for (size_t dimSize : outputDims) {
+        totalElements *= dimSize;
     }
+
+	for (std::size_t oIndex = 0; oIndex < totalElements; ++oIndex) 
+	{
+		std::vector<size_t> indexes = getMultiDimIndices(outputDims, oIndex);
+		std::size_t idx1 = getFlattenedIndex(input1Dims, indexes);
+		std::size_t idx2 = getFlattenedIndex(input2Dims, indexes);
+        output[oIndex] = input_1[idx1] - input_2[idx2];
+	}
 }
 
 namespace {
diff --git a/src/operator/SubImpl.cpp b/src/operator/SubImpl.cpp
index 038a1154..475f8cb8 100644
--- a/src/operator/SubImpl.cpp
+++ b/src/operator/SubImpl.cpp
@@ -17,6 +17,7 @@
 
 #include "aidge/operator/Sub.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/Broadcasting.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/SubImpl.hpp"
@@ -35,9 +36,15 @@ void Aidge::SubImpl_cpu::forward() {
         std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
+    const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
+    const std::vector<std::size_t> inputDims1 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
+                                                                   std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims());
+
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(),
+    kernelFunc(inputDims0,
+        inputDims1,
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
-- 
GitLab