From 9349f51e37856294ba54612a3e6d802762980a09 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 31 Oct 2024 23:01:43 +0000
Subject: [PATCH] [Upd] Conv[DW] 2D kernels

---
 .../operator/ConvDepthWiseImpl_kernels.hpp    |  92 ++++++++++-----
 .../backend/cpu/operator/ConvImpl_kernels.hpp | 105 ++++++++++++------
 2 files changed, 138 insertions(+), 59 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp
index c39cf9cc..2ab00a9d 100644
--- a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp
@@ -150,30 +150,24 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri
     // weight (outCh, ch, kernelX, kernelY)
     // does not take Dilation attribute into account
     using signedsize = std::make_signed<std::size_t>::type;
-    for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
-        for (std::size_t ch = 0; ch < inputDims[1]; ++ch) {
-            const std::size_t oIndex = (ch + batch*inputDims[1]) * oxSize * oySize;
-            B biasVal = (biases != nullptr) ? biases[ch] : B(0);
-            std::fill(output + oIndex, output+(oIndex+oxSize*oySize), biasVal);
-            const std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3];
-            const std::size_t wIndex = ch * kernelDims[0] * kernelDims[1];
-            for (std::size_t ox = 0; ox < oxSize; ++ox) {
-                // const signedsize difx = static_cast<signedsize>(- ox * strideDims[0]);
-                // const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0)));
-                // const std::size_t sxMax = (static_cast<signedsize>(inputDims[2]) + difx) < 0 ? 0 : ((inputDims[2] + difx) > kernelDims[0] ? kernelDims[0] : inputDims[2] + difx);
-                const std::size_t sxMin = 0;
-                const std::size_t sxMax = dilated_kernel_x;
-                for (std::size_t oy = 0; oy < oySize; ++oy) {
-                    // const signedsize dify = static_cast<signedsize>(- oy * strideDims[1]);
-                    // const std::size_t syMin = static_cast<std::size_t>(std::max(dify, signedsize(0)));
-                    // const std::size_t syMax = (static_cast<signedsize>(inputDims[3]) + dify) < 0 ? 0 : ((inputDims[3] + dify) > kernelDims[1] ? kernelDims[1] : inputDims[3] + dify);
-                    const std::size_t syMin = 0;
-                    const std::size_t syMax = dilated_kernel_y;
-                    const std::size_t oIndexFull = oIndex + ox*oySize + oy;
-                    const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
-                    const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
-
-                    if (sxMin == 0 && syMin == 0 && sxMax == 3 && syMax == 3) {
+    const std::size_t outChannels_s =  oxSize * oySize;
+
+    if (kernelDims[0] == 3 && kernelDims[1] == 3 && dilationDims[0] == 1 && dilationDims[1] == 1) { // dense 3x3 only: a 2x2 kernel with dilation 2 also spans 3 but holds just 4 weights
+        for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
+            for (std::size_t ch = 0; ch < inputDims[1]; ++ch) {
+
+                B biasVal = (biases != nullptr) ? biases[ch] : B(0);
+                std::fill(output, output + outChannels_s, biasVal);
+
+                const std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3];
+                const std::size_t wIndex = ch * 9;
+
+                for (std::size_t ox = 0; ox < oxSize; ++ox) {
+                    for (std::size_t oy = 0; oy < oySize; ++oy) {
+                        const std::size_t oIndexFull = ox*oySize + oy;
+                        const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
+                        const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
+
                         output[oIndexFull] +=  (weights[wIndex + 0*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
                                                 weights[wIndex + 0*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
                                                 weights[wIndex + 0*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
@@ -183,9 +177,51 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri
                                                 weights[wIndex + 2*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
                                                 weights[wIndex + 2*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
                                                 weights[wIndex + 2*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+2)]);
-                    } else {
-                        for (std::size_t sx = sxMin; sx*dilationDims[0] < sxMax; ++sx) {
-                            for (std::size_t sy = syMin; sy*dilationDims[1] < syMax; ++sy) {
+                    }
+                }
+                output += outChannels_s;
+            }
+        }
+    } else if (dilated_kernel_x == 1 && dilated_kernel_y == 1) {
+        for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
+            for (std::size_t ch = 0; ch < inputDims[1]; ++ch) {
+                // Start this (batch, channel) output plane from its bias (0 when no bias is given).
+                B biasVal = (biases != nullptr) ? biases[ch] : B(0);
+                std::fill(output, output + outChannels_s, biasVal);
+
+                const std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3];
+                const std::size_t wIndex = ch;
+                for (std::size_t ox = 0; ox < oxSize; ++ox) {
+                    for (std::size_t oy = 0; oy < oySize; ++oy) {
+                        // 1x1 path: one weight per channel, applied to the strided input sample.
+                        const std::size_t oIndexFull = ox*oySize + oy;
+                        const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
+                        const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
+                        output[oIndexFull] += weights[wIndex] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy)];
+                    }
+                }
+                output += outChannels_s; // advance once per channel, otherwise all channels of a batch overwrite the same plane
+            }
+        }
+    } else {
+        for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
+            for (std::size_t ch = 0; ch < inputDims[1]; ++ch) {
+
+                B biasVal = (biases != nullptr) ? biases[ch] : B(0);
+                std::fill(output, output+outChannels_s, biasVal);
+
+                const std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3];
+                const std::size_t wIndex = ch * kernelDims[0] * kernelDims[1];
+
+                for (std::size_t ox = 0; ox < oxSize; ++ox) {
+                    for (std::size_t oy = 0; oy < oySize; ++oy) {
+
+                        const std::size_t oIndexFull = ox*oySize + oy;
+                        const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
+                        const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
+
+                        for (std::size_t sx = 0; sx*dilationDims[0] < dilated_kernel_x; ++sx) {
+                            for (std::size_t sy = 0; sy*dilationDims[1] < dilated_kernel_y; ++sy) {
                                 output[oIndexFull] += weights[wIndex + sx*kernelDims[1] + sy] *
                                                         input[iIndex + static_cast<std::size_t>(ix+static_cast<signedsize>(sx*dilationDims[0]))*inputDims[3] + static_cast<std::size_t>(iy+static_cast<signedsize>(sy*dilationDims[1]))];
                             }
@@ -193,10 +229,12 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri
                     }
                 }
+                output += outChannels_s; // advance once per channel (inside the channel loop), matching the 3x3 branch
             }
         }
     }
 }
 
+
 // Kernels registration to implementation entry point
 REGISTRAR(ConvDepthWiseImpl2D_cpu,
     {{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}},
diff --git a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp
index 71538eaa..0cf079a9 100644
--- a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp
@@ -141,15 +141,15 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
     O *output = static_cast<O *>(output_);
 
     // output H size
+    const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1;
     const std::size_t oxSize =
-            static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilationDims[0]*(kernelDims[0] - 1) - 1 + strideDims[0]) /
+            static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) /
                                 static_cast<float>(strideDims[0])));
-    const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1;
     // output W size
+    const DimSize_t dilated_kernel_y = dilationDims[1]*(kernelDims[1] - 1) + 1;
     const std::size_t oySize =
-            static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[3] - dilationDims[1]*(kernelDims[1] - 1) - 1 + strideDims[1]) /
+            static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) /
                                 static_cast<float>(strideDims[1])));
-    const DimSize_t dilated_kernel_y = dilationDims[1]*(kernelDims[1] - 1) + 1;
 
 
     // TODO: kernel computation
@@ -158,51 +158,92 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
     // weight (outCh, inCh, kernelX, kernelY)
     // does not take Dilation attribute into account
     const std::size_t outChannels_s =  oxSize * oySize;
+    using signedsize = std::make_signed<std::size_t>::type;
 
-    for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
-        for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
-            // If bias = nullptr, set B(0)
-            B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
-            std::fill(output, output+outChannels_s, biasVal);
+    if (kernelDims[0] == 3 && kernelDims[1] == 3 && dilationDims[0] == 1 && dilationDims[1] == 1) { // dense 3x3 only: a 2x2 kernel with dilation 2 also spans 3 but holds just 4 weights
+        for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
+            for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
+                // If bias = nullptr, set B(0)
+                B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
+                std::fill(output, output+outChannels_s, biasVal);
+                for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
+                    const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
+                    const std::size_t wIndex = (inCh + outCh*inputDims[1]) * 9;
+                    for (std::size_t ox = 0; ox < oxSize; ++ox) {
+                        for (std::size_t oy = 0; oy < oySize; ++oy) {
+                            const std::size_t oIndexFull = ox*oySize + oy;
+                            const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
+                            const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
 
-            for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
-                const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
-                const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1];
-                for (std::size_t ox = 0; ox < oxSize; ++ox) {
+                            output[oIndexFull] += (weights[wIndex + 0*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
+                                                weights[wIndex + 0*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
+                                                weights[wIndex + 0*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
+                                                weights[wIndex + 1*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
+                                                weights[wIndex + 1*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
+                                                weights[wIndex + 1*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
+                                                weights[wIndex + 2*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
+                                                weights[wIndex + 2*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
+                                                weights[wIndex + 2*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+2)]);
+                        }
+                    }
+                }
+                output += outChannels_s;
+            }
+        }
+    } else if (dilated_kernel_x == 1 && dilated_kernel_y == 1) {
+        for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
+            for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
+                // If bias = nullptr, set B(0)
+                B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
+                std::fill(output, output+outChannels_s, biasVal);
+                for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
+                    const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
+                    const std::size_t wIndex = (inCh + outCh*inputDims[1]);
+                    for (std::size_t ox = 0; ox < oxSize; ++ox) {
+                        for (std::size_t oy = 0; oy < oySize; ++oy) {
+                            const std::size_t oIndexFull = ox*oySize + oy;
+                            const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
+                            const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
+
+                            output[oIndexFull] += weights[wIndex] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy)];
+                        }
+                    }
+                }
+                output += outChannels_s;
+            }
+        }
+    } else {
+        for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
+            for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
+                // If bias = nullptr, set B(0)
+                B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
+                std::fill(output, output+outChannels_s, biasVal);
+                for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
+                    const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
+                    const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1];
+                    for (std::size_t ox = 0; ox < oxSize; ++ox) {
+                        for (std::size_t oy = 0; oy < oySize; ++oy) {
+                            const std::size_t oIndexFull = ox*oySize + oy;
+                            const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
+                            const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
 
-                    for (std::size_t oy = 0; oy < oySize; ++oy) {
-
-                        const std::size_t oIndexFull = ox*oySize + oy;
-                        const size_t ix = ox * strideDims[0];
-                        const size_t iy = oy * strideDims[1];
-
-                        if (kernelDims[0] == 3 && kernelDims[1] == 3 && dilationDims[0] == 1 && dilationDims[1] == 1) {
-                            output[oIndexFull] += (weights[wIndex] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy)] +
-                                                   weights[wIndex + 1] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
-                                                   weights[wIndex + 2] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
-                                                   weights[wIndex + kernelDims[1]] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy)] +
-                                                   weights[wIndex + kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
-                                                   weights[wIndex + kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
-                                                   weights[wIndex + 2*kernelDims[1]] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy)] +
-                                                   weights[wIndex + 2*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
-                                                   weights[wIndex + 2*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+2)]);
-                        } else {
                             for (std::size_t sx = 0; sx*dilationDims[0] < dilated_kernel_x; ++sx) {
                                 for (std::size_t sy = 0; sy*dilationDims[1] < dilated_kernel_y; ++sy) {
                                     output[oIndexFull] += weights[wIndex + sx*kernelDims[1] + sy] *
-                                                            input[iIndex + (ix + (sx*dilationDims[0]))*inputDims[3] + (iy + (sy*dilationDims[1]))];
+                                                            input[iIndex + static_cast<std::size_t>(ix+static_cast<signedsize>(sx*dilationDims[0]))*inputDims[3] + static_cast<std::size_t>(iy+static_cast<signedsize>(sy*dilationDims[1]))];
                                 }
                             }
                         }
                     }
                 }
+                output += outChannels_s;
             }
-            output += outChannels_s;
         }
     }
 }
 
 
+
 // Kernels registration to implementation entry point
 REGISTRAR(ConvImpl2D_cpu,
     {{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}},
-- 
GitLab