From a814fc02df5cb6f5850a99a350d2e6c986da5838 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Mon, 3 Feb 2025 10:11:02 +0100
Subject: [PATCH] handle ceil_mode in pooling kernels

---
 .../cpu/operator/AvgPoolingImpl_kernels.hpp   | 56 ++++++++++++-------
 .../cpu/operator/MaxPoolingImpl_kernels.hpp   | 20 ++++---
 unit_tests/operator/Test_AvgPoolingImpl.cpp   | 35 +++++++++++-
 3 files changed, 82 insertions(+), 29 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/AvgPoolingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/AvgPoolingImpl_kernels.hpp
index 68dbfbe7..78f8446a 100644
--- a/include/aidge/backend/cpu/operator/AvgPoolingImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/AvgPoolingImpl_kernels.hpp
@@ -43,15 +43,20 @@ void AvgPoolingImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideD
     const I *input = static_cast<const I *>(input_);
     O *output = static_cast<O *>(output_);
 
-    // Calculate output dimensions based on ceilMode and dilations
-    auto compute_output_size = [&](DimSize_t inputDim, DimSize_t kernelDim, DimSize_t stride, DimSize_t dilation) {
-        DimSize_t effectiveKernelDim = (kernelDim - 1) * dilation + 1;
-        float result = static_cast<float>(inputDim - effectiveKernelDim + stride) / static_cast<float>(stride);
-        return ceilMode ? static_cast<DimSize_t>(std::ceil(result)) : static_cast<DimSize_t>(std::floor(result));
-    };
-
-    const std::size_t oxSize = compute_output_size(dims[2], kernelDims[0], strideDims[0], dilations[0]);
-    const std::size_t oySize = compute_output_size(dims[3], kernelDims[1], strideDims[1], dilations[1]);
+    // output H size: ceil (if ceilMode) or floor of (H - dilated kernel extent + stride) / stride
+    const std::size_t oxSize = 
+        ceilMode 
+        ? static_cast<std::size_t>(std::ceil(static_cast<float>(dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]) /
+                                            static_cast<float>(strideDims[0])))
+        : static_cast<std::size_t>(std::floor(static_cast<float>(dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]) /
+                                            static_cast<float>(strideDims[0])));
+    // output W size: ceil (if ceilMode) or floor of (W - dilated kernel extent + stride) / stride
+    const std::size_t oySize = 
+        ceilMode 
+        ? static_cast<std::size_t>(std::ceil(static_cast<float>(dims[3] - (kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]) /
+                                            static_cast<float>(strideDims[1])))
+        : static_cast<std::size_t>(std::floor(static_cast<float>(dims[3] - (kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]) /
+                                            static_cast<float>(strideDims[1])));
 
     using signedsize = std::make_signed<std::size_t>::type;
 
@@ -59,30 +64,39 @@ void AvgPoolingImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideD
         for (std::size_t ch = 0; ch < dims[1]; ++ch) {
             const std::size_t oIndex = (ch + batch * dims[1]) * oxSize * oySize;
             const std::size_t iIndex = (ch + batch * dims[1]) * dims[2] * dims[3];
-            std::fill(output + oIndex, output + (oIndex + oxSize * oySize), 0);
 
             for (std::size_t ox = 0; ox < oxSize; ++ox) {
-                const signedsize startx = static_cast<signedsize>(ox * strideDims[0]) - (dilations[0] - 1);
-                const std::size_t sxMin = static_cast<std::size_t>(std::max(startx, signedsize(0)));
-                const std::size_t sxMax = std::min(dims[2], static_cast<std::size_t>(startx + kernelDims[0] * dilations[0]));
+                const signedsize difx = static_cast<signedsize>(-ox * strideDims[0]);
+                const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0)));
+                const std::size_t sxMax = (static_cast<signedsize>(dims[2]) + difx) < 0 ? 0 : ((dims[2] + difx) > kernelDims[0] ? kernelDims[0] : dims[2] + difx);
 
                 for (std::size_t oy = 0; oy < oySize; ++oy) {
-                    const signedsize starty = static_cast<signedsize>(oy * strideDims[1]) - (dilations[1] - 1);
-                    const std::size_t syMin = static_cast<std::size_t>(std::max(starty, signedsize(0)));
-                    const std::size_t syMax = std::min(dims[3], static_cast<std::size_t>(starty + kernelDims[1] * dilations[1]));
+                    const signedsize dify = static_cast<signedsize>(-oy * strideDims[1]);
+                    const std::size_t syMin = static_cast<std::size_t>(std::max(dify, signedsize(0)));
+                    const std::size_t syMax = (static_cast<signedsize>(dims[3]) + dify) < 0 ? 0 : ((dims[3] + dify) > kernelDims[1] ? kernelDims[1] : dims[3] + dify);
 
                     const std::size_t oIndexFull = oIndex + ox * oySize + oy;
+                    const std::size_t ix = ox * strideDims[0];
+                    const std::size_t iy = oy * strideDims[1];
+
                     O sum = static_cast<O>(0);
                     std::size_t count = 0;
 
-                    for (std::size_t sx = sxMin; sx < sxMax; sx += dilations[0]) {
-                        for (std::size_t sy = syMin; sy < syMax; sy += dilations[1]) {
-                            sum += static_cast<O>(input[iIndex + sx * dims[3] + sy]);
-                            ++count;
+                    for (unsigned int sy = syMin; sy < syMax; ++sy) {
+                        for (unsigned int sx = sxMin; sx < sxMax; ++sx) {
+                            // Apply dilation factor
+                            const std::size_t dilated_sx = sx * dilations[0];
+                            const std::size_t dilated_sy = sy * dilations[1];
+
+                            // Skip taps that land outside the input, so `count` only covers valid elements
+                            if ((ix + dilated_sx) < dims[2] && (iy + dilated_sy) < dims[3]) {
+                                sum += static_cast<O>(input[iIndex + (ix + dilated_sx) * dims[3] + (iy + dilated_sy)]);
+                                ++count;
+                            }
                         }
                     }
 
-                    output[oIndexFull] = sum / static_cast<O>(count);
+                    output[oIndexFull] = count > 0 ? sum / static_cast<O>(count) : 0;
                 }
             }
         }
diff --git a/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp b/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp
index 250b11b0..d5ac02fe 100644
--- a/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/MaxPoolingImpl_kernels.hpp
@@ -36,7 +36,7 @@ template <class I, class O>
 void MaxPoolingImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
                                         const std::array<DimSize_t, 2>& kernelDims,
                                         const std::array<DimSize_t, 2>& dilations,
-                                        const bool /*ceilMode*/,
+                                        const bool ceilMode,
                                         const std::array<DimSize_t, 4> &dims,
                                         const void *input_,
                                         void *output_) {
@@ -44,13 +44,19 @@ void MaxPoolingImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideD
     O *output = static_cast<O *>(output_);
 
     // output H size
-    const std::size_t oxSize =
-            static_cast<std::size_t>(std::floor(static_cast<float>(dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]) /
-                                static_cast<float>(strideDims[0])));
+    const std::size_t oxSize = 
+        ceilMode 
+        ? static_cast<std::size_t>(std::ceil(static_cast<float>(dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]) /
+                                            static_cast<float>(strideDims[0])))
+        : static_cast<std::size_t>(std::floor(static_cast<float>(dims[2] - (kernelDims[0] - 1) * dilations[0] - 1 + strideDims[0]) /
+                                            static_cast<float>(strideDims[0])));
     // output W size
-    const std::size_t oySize =
-            static_cast<std::size_t>(std::floor(static_cast<float>(dims[3] - (kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]) /
-                                static_cast<float>(strideDims[1])));
+    const std::size_t oySize = 
+        ceilMode 
+        ? static_cast<std::size_t>(std::ceil(static_cast<float>(dims[3] - (kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]) /
+                                            static_cast<float>(strideDims[1])))
+        : static_cast<std::size_t>(std::floor(static_cast<float>(dims[3] - (kernelDims[1] - 1) * dilations[1] - 1 + strideDims[1]) /
+                                            static_cast<float>(strideDims[1])));
 
     using signedsize = std::make_signed<std::size_t>::type;
     for (std::size_t batch = 0; batch < dims[0]; ++batch) {
diff --git a/unit_tests/operator/Test_AvgPoolingImpl.cpp b/unit_tests/operator/Test_AvgPoolingImpl.cpp
index 372febc6..21a7a680 100644
--- a/unit_tests/operator/Test_AvgPoolingImpl.cpp
+++ b/unit_tests/operator/Test_AvgPoolingImpl.cpp
@@ -110,5 +110,38 @@ TEST_CASE("[cpu/operator] AvgPooling(forward)", "[AvgPooling][CPU]") {
             REQUIRE(std::abs(outPtr[i] - expectedOutPtr[i]) < 0.00001);
         }
     }
-    // std::cout << static_cast<Tensor>((*op)["weight"])[0][0][0][0] << std::endl;
+    SECTION("Dilations") {
+        std::shared_ptr<Tensor> myInput3 = std::make_shared<Tensor>(Array4D<float,1,1,5,5> { // NCHW
+        {
+            {
+                {{ 1,  2,  3,  4,  5},
+                { 6,  7,  8,  9, 10},
+                {11, 12, 13, 14, 15},
+                {16, 17, 18, 19, 20},
+                {21, 22, 23, 24, 25}}
+            }
+        }
+        });
+
+        // With a dilation of 2, the 2x2 kernel samples every second element, i.e. the four corners of each 3x3 window
+        std::shared_ptr<Node> myAvgPool = AvgPooling({2,2}, "mycdw", {1,1}, {2,2}); 
+        auto op = std::static_pointer_cast<AvgPooling_Op<2>>(myAvgPool -> getOperator());
+
+        std::shared_ptr<Tensor> myOutput3 = std::make_shared<Tensor>(Array4D<float,1,1,3,3> {
+            {
+                {
+                    {{  7,  8,  9},
+                    { 12, 13, 14},
+                    { 17, 18, 19}}
+                }
+            }
+        });
+
+        op->associateInput(0, myInput3);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cpu");
+        myAvgPool->forward();
+        op->getOutput(0)->print();
+        REQUIRE(*(op->getOutput(0)) == *myOutput3);
+    }
 }
\ No newline at end of file
-- 
GitLab