From f1813259f46a25a73f45fd0740a2e6ad43c6096a Mon Sep 17 00:00:00 2001
From: Adam Maroni <adamaroni@hotmail.fr>
Date: Sun, 23 Mar 2025 18:56:49 +0100
Subject: [PATCH] Refactoring of MaxPoolingImpl_kernels.hpp

---
 .../cpu/operator/MaxPoolingImpl_kernels.hpp   | 147 +++++++++++-------
 1 file changed, 87 insertions(+), 60 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp
index 027fc02a..21eefb02 100644
--- a/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp
@@ -14,6 +14,7 @@
 
 #include <array>
 #include <cmath>
+#include <cstdint>
 #include <tuple>
 
 
@@ -34,75 +35,101 @@ namespace Aidge {
  * @param output_ Output Tensor.
  */
 template <class I, class O>
-void MaxPoolingImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
-                                        const std::array<DimSize_t, 2>& kernelDims,
-                                        const std::array<DimSize_t, 2>& dilations,
-                                        const bool ceilMode,
-                                        const std::array<DimSize_t, 4> &dims,
-                                        const void *input_,
-                                        void *output_) {
-    const I *input = static_cast<const I *>(input_);
-    O *output = static_cast<O *>(output_);
+void MaxPoolingImpl2D_cpu_forward_kernel(
+  const std::array<DimSize_t, 2>& strideDims,
+  const std::array<DimSize_t, 2>& kernelDims,
+  const std::array<DimSize_t, 2>& dilations,
+  const bool ceilMode,
+  const std::array<DimSize_t, 4> &dims,
+  const void *input_,
+  void *output_)
+{
+  const I *input = static_cast<const I *>(input_);
+  O *output = static_cast<O *>(output_);
 
-    // output H size
-    const std::size_t oxSize = 
-        ceilMode 
-        ? static_cast<std::size_t>(std::ceil(static_cast<float>(dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]) /
-                                            static_cast<float>(strideDims[0])))
-        : static_cast<std::size_t>(std::floor(static_cast<float>(dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]) /
-                                            static_cast<float>(strideDims[0])));
-    // output W size
-    const std::size_t oySize = 
-        ceilMode 
-        ? static_cast<std::size_t>(std::ceil(static_cast<float>(dims[3] - (kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]) /
-                                            static_cast<float>(strideDims[1])))
-        : static_cast<std::size_t>(std::floor(static_cast<float>(dims[3] - (kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]) /
-                                            static_cast<float>(strideDims[1])));
+  // output H size
+  auto hOut = static_cast<float>(
+    dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]
+  ) / static_cast<float>(strideDims[0]);
+  const std::size_t outXSize = ceilMode
+    ? static_cast<std::size_t>(std::ceil(hOut))
+    : static_cast<std::size_t>(std::floor(hOut));
 
-    using signedsize = std::make_signed<std::size_t>::type;
-    for (std::size_t batch = 0; batch < dims[0]; ++batch) {
-        for (std::size_t ch = 0; ch < dims[1]; ++ch) {
-            const std::size_t oIndex = (ch + batch*dims[1]) * oxSize * oySize;
-            const std::size_t iIndex = (ch + batch*dims[1]) * dims[2] * dims[3];
-            for (std::size_t ox = 0; ox < oxSize; ++ox) {
-                const signedsize difx = static_cast<signedsize>(- ox * strideDims[0]);
-                const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0)));
-                const std::size_t sxMax = (static_cast<signedsize>(dims[2]) + difx) < 0 ? 0 : ((dims[2] + difx) > kernelDims[0] ? kernelDims[0] : dims[2] + difx);
-                for (std::size_t oy = 0; oy < oySize; ++oy) {
-                    const signedsize dify = static_cast<signedsize>(- oy * strideDims[1]);
-                    const std::size_t syMin = static_cast<std::size_t>(std::max(dify, signedsize(0)));
-                    const std::size_t syMax = (static_cast<signedsize>(dims[3]) + dify) < 0 ? 0 : ((dims[3] + dify) > kernelDims[1] ? kernelDims[1] : dims[3] + dify);
-                    const std::size_t oIndexFull = oIndex + ox*oySize + oy;
-                    const std::size_t ix = ox * strideDims[0];
-                    const std::size_t iy = oy * strideDims[1];
+  // output W size
+  auto wOut = static_cast<float>( 
+      dims[3] - ( kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]
+    ) / static_cast<float>(strideDims[1]);
 
-                    I poolValue(0.0);
-                    bool valid = false;
+  const std::size_t outYSize = ceilMode
+    ? static_cast<std::size_t>(std::ceil(wOut))
+    : static_cast<std::size_t>(std::floor(wOut));
 
-                    for (unsigned int sy = syMin; sy < syMax; ++sy) {
-                        for (unsigned int sx = sxMin; sx < sxMax; ++sx) {
-                            // Apply dilation factor to kernel indices
-                            const std::size_t dilated_sx = sx * dilations[0];
-                            const std::size_t dilated_sy = sy * dilations[1];
+  using signedsize = std::make_signed<std::size_t>::type;
 
-                            // Ensure indices are within bounds
-                            if ((ix + dilated_sx) < dims[2] && (iy + dilated_sy) < dims[3]) {
-                                const I value = input[iIndex + (ix + dilated_sx) * dims[3] + (iy + dilated_sy)];
-
-                                if (!valid || value > poolValue) {
-                                    poolValue = value;
-                                    valid = true;
-                                }
-                            }
-                        }
-                    }
-                    output[oIndexFull] = poolValue;
+  for (std::size_t batch = 0; batch < dims[0]; ++batch){
+    for (std::size_t channel = 0; channel < dims[1]; ++channel){
+      auto batchChannelIndex = (channel + batch * dims[1]);
+      const std::size_t outputBaseIndex = batchChannelIndex * outXSize * outYSize;
+      const std::size_t inputBaseIndex = batchChannelIndex * dims[2] * dims[3];
+      for (std::size_t outX = 0; outX < outXSize; ++outX) {
+        const signedsize negStrideX = static_cast<signedsize>(
+		-outX * strideDims[0]
+	);
+        const std::size_t kernelXMin = static_cast<std::size_t>(
+          std::max(negStrideX, signedsize(0))
+        );
+        /* Compute kernelXMax */
+        std::size_t kernelXMax = dims[2] + negStrideX;
+        if ((static_cast<signedsize>(dims[2]) + negStrideX) < 0){
+          kernelXMax = 0;
+        }
+        else if (kernelXMax > kernelDims[0]){
+          kernelXMax = kernelDims[0];
+        }
+        for (std::size_t outY = 0; outY < outYSize; ++outY) {
+          const signedsize negStrideY = static_cast<signedsize>(-outY * strideDims[1]);
+          const std::size_t kernelYMin = static_cast<std::size_t>(
+            std::max(negStrideY, signedsize(0))
+          );
+          /* Compute kernelYMax */
+          std::size_t kernelYMax = dims[3] + negStrideY;
+          const std::size_t outputIndex = outputBaseIndex + outX * outYSize + outY;
+          const std::size_t strideXoffset = outX * strideDims[0];
+          const std::size_t strideYoffset = outY * strideDims[1];
+          I poolValue(0.0);
+          bool valid = false;
+          if (static_cast<signedsize>(dims[3]) + negStrideY < 0){
+            kernelYMax = 0;
+          }
+          else if(kernelYMax > kernelDims[1]){
+            kernelYMax = kernelDims[1];
+          }
+          for (unsigned int kY = kernelYMin; kY < kernelYMax ; ++kY){
+            for (unsigned int kX = kernelXMin; kX < kernelXMax; ++kX){
+              // Apply dilation factor to kernel indices
+              const std::size_t dilatedkernelX = kX * dilations[0];
+              const std::size_t dilatedkernelY = kY * dilations[1];
+              // Ensure indices are within bounds
+              auto inputXPostDilation = strideXoffset + dilatedkernelX;
+              auto inputYPostDilation = strideYoffset + dilatedkernelY;
+              if (inputXPostDilation < dims[2] && inputYPostDilation < dims[3]){
+                const I inputValue = input[
+		    inputBaseIndex + inputXPostDilation * dims[3] 
+		    + inputYPostDilation
+                ];
+                if (!valid || inputValue > poolValue) {
+                  poolValue = inputValue;
+                  valid = true;
                 }
+              }
             }
+          }
+          output[outputIndex] = poolValue;
         }
+      }
     }
-}
-
+  }
+} 
 
 // Kernels registration to implementation entry point
 REGISTRAR(MaxPoolingImpl2D_cpu,
-- 
GitLab