From 6d67e92e30f6c9d22cf9e0bbba2db013e7fec0c9 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 30 Jun 2024 23:46:25 +0200
Subject: [PATCH] Fixes

---
 src/operator/Unfold.cpp | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/operator/Unfold.cpp b/src/operator/Unfold.cpp
index eed950ac2..5b651846b 100644
--- a/src/operator/Unfold.cpp
+++ b/src/operator/Unfold.cpp
@@ -31,6 +31,7 @@ void Aidge::Unfold_OpImpl<DIM>::forward() {
     const auto strideDims = op.template getAttr<UnfoldAttr::StrideDims>();
     const DimSize_t inHeight = op.getInput(0)->dims()[2];
     const DimSize_t inWidth = op.getInput(0)->dims()[3];
+    const DimSize_t inChannels = op.getInput(0)->dims()[1];
 
     const DimSize_t kernelExtentHeight = op.template getAttr<UnfoldAttr::DilationDims>()[0] *
                                             (op.template getAttr<UnfoldAttr::KernelDims>()[0] - 1) + 1;
@@ -44,19 +45,21 @@ void Aidge::Unfold_OpImpl<DIM>::forward() {
                             static_cast<float>(op.template getAttr<UnfoldAttr::StrideDims>()[1])));
     const DimSize_t outChannels = op.getOutput(0)->dims()[1];
 
-    for (DimSize_t outC = 0; outC < outChannels; ++outC) {
-        const auto inOffsetH = outC % kernelDims[0];
-        const auto inOffsetW = (outC / kernelDims[0]) % kernelDims[1];
-        const auto inC = outC / kernelDims[0] / kernelDims[1];
+    for (DimSize_t n = 0; n < op.getOutput(0)->dims()[0]; ++n) {
+        for (DimSize_t outC = 0; outC < outChannels; ++outC) {
+            const auto inOffsetH = outC % kernelDims[1];
+            const auto inOffsetW = (outC / kernelDims[1]) % kernelDims[0];
+            const auto inC = outC / kernelDims[0] / kernelDims[1];
 
-        for (DimSize_t outH = 0; outH < outHeight; ++outH) {
-            const auto inH = outH * strideDims[1] + inOffsetH * dilationDims[1];
+            for (DimSize_t outH = 0; outH < outHeight; ++outH) {
+                const auto inH = outH * strideDims[0] + inOffsetH * dilationDims[0];
 
-            for (DimSize_t outW = 0; outW < outWidth; ++outW) {
-                const auto inW = outW * strideDims[0] + inOffsetW * dilationDims[0];
+                for (DimSize_t outW = 0; outW < outWidth; ++outW) {
+                    const auto inW = outW * strideDims[1] + inOffsetW * dilationDims[1];
 
-                op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr((inC * inHeight + inH) * inWidth + inW), 1,
-                    (outC * outHeight + outH) * outWidth + outW);
+                    op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(((n * inChannels + inC) * inHeight + inH) * inWidth + inW), 1,
+                        ((n * outChannels + outC) * outHeight + outH) * outWidth + outW);
+                }
             }
         }
     }
-- 
GitLab