diff --git a/include/aidge/backend/cpu/operator/FoldImpl.hpp b/include/aidge/backend/cpu/operator/FoldImpl.hpp
index b258745ef772f12a744ff2e43dcdf5887825f17f..a0c7e509cddbd1b33b8360aed4a8bbce4a39dcac 100644
--- a/include/aidge/backend/cpu/operator/FoldImpl.hpp
+++ b/include/aidge/backend/cpu/operator/FoldImpl.hpp
@@ -27,7 +27,7 @@ namespace Aidge {
 class FoldImpl2DForward_cpu
     : public Registrable<FoldImpl2DForward_cpu,
                          std::tuple<DataType, DataType>,
-                         void(const Fold_Op<2>::Attrs &, const std::array<DimSize_t, 4> &, const void *,
+                         void(const Fold_Op<2>::Attrs &, const std::vector<DimSize_t> &, const void *,
                               void *)> {};
 
 class FoldImpl2D_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/FoldImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/FoldImpl_forward_kernels.hpp
index 5caad1472a6719042431300ce3bd16e3bc9669f8..81b96d4b3b1e9c59ad424967eae54b75224f83bd 100644
--- a/include/aidge/backend/cpu/operator/FoldImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/FoldImpl_forward_kernels.hpp
@@ -23,7 +23,7 @@
 
 namespace Aidge {
 template <class I, class O>
-void FoldImpl2D_cpu_forward_kernel(const Fold_Op<2>::Attrs &attrs, const std::array<DimSize_t, 4> &dims,
+void FoldImpl2D_cpu_forward_kernel(const Fold_Op<2>::Attrs &attrs, const std::vector<DimSize_t> &dims,
                                        const void *input_, void *output_)
 {
     const I *input = static_cast<const I *>(input_);
@@ -45,22 +45,26 @@ void FoldImpl2D_cpu_forward_kernel(const Fold_Op<2>::Attrs &attrs, const std::ar
     const DimSize_t outWidth = 1 + static_cast<DimSize_t>(
                     floor(static_cast<float>(inWidth - kernelExtentWidth) /
                             static_cast<float>(strideDims[1])));
-    const DimSize_t outChannels = dims[1];
+    const DimSize_t outChannels = dims[dims.size() - 2];
+    const DimSize_t inChannels = outChannels / kernelDims[0] / kernelDims[1];
 
-    std::fill_n(output, outHeight * outWidth * outChannels, O(0));
+    std::fill_n(output, dims[0] * outHeight * outWidth * outChannels, O(0));
 
-    for (DimSize_t outC = 0; outC < outChannels; ++outC) {
-        const auto inOffsetH = outC % kernelDims[0];
-        const auto inOffsetW = (outC / kernelDims[0]) % kernelDims[1];
-        const auto inC = outC / kernelDims[0] / kernelDims[1];
+    for (DimSize_t n = 0; n < dims[0]; ++n) {
+        for (DimSize_t outC = 0; outC < outChannels; ++outC) {
+            const auto inOffsetH = outC % kernelDims[1];
+            const auto inOffsetW = (outC / kernelDims[1]) % kernelDims[0];
+            const auto inC = outC / kernelDims[0] / kernelDims[1];
 
-        for (DimSize_t outH = 0; outH < outHeight; ++outH) {
-            const auto inH = outH * strideDims[1] + inOffsetH * dilationDims[1];
+            for (DimSize_t outH = 0; outH < outHeight; ++outH) {
+                const auto inH = outH * strideDims[0] + inOffsetH * dilationDims[0];
 
-            for (DimSize_t outW = 0; outW < outWidth; ++outW) {
-                const auto inW = outW * strideDims[0] + inOffsetW * dilationDims[0];
+                for (DimSize_t outW = 0; outW < outWidth; ++outW) {
+                    const auto inW = outW * strideDims[1] + inOffsetW * dilationDims[1];
 
-                output[(inC * inHeight + inH) * inWidth + inW] += input[(outC * outHeight + outH) * outWidth + outW];
+                    output[((n * inChannels + inC) * inHeight + inH) * inWidth + inW] +=
+                        input[((n * outChannels + outC) * outHeight + outH) * outWidth + outW];
+                }
             }
         }
     }
diff --git a/src/operator/FoldImpl.cpp b/src/operator/FoldImpl.cpp
index a885db4c3de9527bb1173600c05eac65cde46e74..bcb0b4b029f4c8cd899e123f62519987fc432a3d 100644
--- a/src/operator/FoldImpl.cpp
+++ b/src/operator/FoldImpl.cpp
@@ -31,7 +31,7 @@ void Aidge::FoldImpl2D_cpu::forward() {
 
     // Call kernel
     kernelFunc(dynamic_cast<const Fold_Op<2>&>(mOp).getStaticAttributes(),
-                        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
+                        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
                         getCPUPtr(mOp.getRawInput(0)),
                         getCPUPtr(mOp.getRawOutput(0)));
 }
diff --git a/unit_tests/operator/Test_FoldImpl.cpp b/unit_tests/operator/Test_FoldImpl.cpp
index 079fca62b729d6b68e81f2891bf85a89defeed68..010b23d395acea79a0b2503e3a43f10a11661f77 100644
--- a/unit_tests/operator/Test_FoldImpl.cpp
+++ b/unit_tests/operator/Test_FoldImpl.cpp
@@ -27,9 +27,9 @@ using namespace Aidge;
 
 TEST_CASE("[cpu/operator] Fold(forward)", "[Fold][CPU]") {
     std::shared_ptr<Node> myUnfold = Unfold({3,3}, "myunfold");
-    std::shared_ptr<Node> myReshape = Reshape({9, 12}, "myreshape");
+    std::shared_ptr<Node> myReshape = Reshape({4, 27}, "myreshape");
     std::shared_ptr<Node> myMatMul = MatMul("mymatmul");
-    std::shared_ptr<Node> myFold = Fold({3,3}, {3,3}, "myfold");
+    std::shared_ptr<Node> myFold = Fold({3,3}, {1,1}, "myfold");
     myUnfold->addChild(myMatMul, 0, 1);
     myReshape->addChild(myMatMul, 0, 0);
     myMatMul->addChild(myFold, 0, 0);
@@ -127,32 +127,32 @@ TEST_CASE("[cpu/operator] Fold(forward)", "[Fold][CPU]") {
     std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<int,2,4,3,3> {
         {
             {
-                {{ 15226,  15577,  15928},
-                { 16981,  17332,  17683},
-                { 18736,  19087,  19438}},
-                {{ 37818,  38898,  39978},
-                { 43218,  44298,  45378},
-                { 48618,  49698,  50778}},
-                {{ 60426,  62235,  64044},
-                { 69471,  71280,  73089},
-                { 78516,  80325,  82134}},
-                {{ 83016,  85554,  88092},
-                { 95706,  98244, 100782},
-                {108396, 110934, 113472}}
+                {{ 15219, 15570, 15921},
+                { 16974, 17325, 17676},
+                { 18729, 19080, 19431}},
+                {{ 37818, 38898, 39978},
+                { 43218, 44298, 45378},
+                { 48618, 49698, 50778}},
+                {{ 60417, 62226, 64035},
+                { 69462, 71271, 73080},
+                { 78507, 80316, 82125}},
+                {{ 83016, 85554, 88092},
+                { 95706, 98244, 100782},
+                { 108396, 110934, 113472}}
             },
             {
-                {{ 41551,  41902,  42253},
-                { 43306,  43657,  44008},
-                { 45061,  45412,  45763}},
-                {{118818, 119898, 120978},
-                {124218, 125298, 126378},
-                {129618, 130698, 131778}},
-                {{196101, 197910, 199719},
-                {205146, 206955, 208764},
-                {214191, 216000, 217809}},
-                {{273366, 275904, 278442},
-                {286056, 288594, 291132},
-                {298746, 301284, 303822}}
+                {{ 41544, 41895, 42246},
+                { 43299, 43650, 44001},
+                { 45054, 45405, 45756}},
+                {{ 118818, 119898, 120978},
+                { 124218, 125298, 126378},
+                { 129618, 130698, 131778}},
+                {{ 196092, 197901, 199710},
+                { 205137, 206946, 208755},
+                { 214182, 215991, 217800}},
+                {{ 273366, 275904, 278442},
+                { 286056, 288594, 291132},
+                { 298746, 301284, 303822}}
             }
         }
     });
@@ -168,12 +168,11 @@ TEST_CASE("[cpu/operator] Fold(forward)", "[Fold][CPU]") {
     g->setDataType(DataType::Int32);
     g->setBackend("cpu");
 
-    g->save("unfold_matmul_fold");
     g->forwardDims();
     g->save("unfold_matmul_fold");
 
     SequentialScheduler scheduler(g);
     scheduler.forward();
-    // op->getOutput(0)->print();
+    opFold->getOutput(0)->print();
     REQUIRE(*(opFold->getOutput(0)) == *myOutput);
 }
\ No newline at end of file