diff --git a/src/operator/Unfold.cpp b/src/operator/Unfold.cpp index eed950ac24b370ba379a090e63d839e85fb59af1..5b651846b3c8744d9bbe594a0578c5c2787f722c 100644 --- a/src/operator/Unfold.cpp +++ b/src/operator/Unfold.cpp @@ -31,6 +31,7 @@ void Aidge::Unfold_OpImpl<DIM>::forward() { const auto strideDims = op.template getAttr<UnfoldAttr::StrideDims>(); const DimSize_t inHeight = op.getInput(0)->dims()[2]; const DimSize_t inWidth = op.getInput(0)->dims()[3]; + const DimSize_t inChannels = op.getInput(0)->dims()[1]; const DimSize_t kernelExtentHeight = op.template getAttr<UnfoldAttr::DilationDims>()[0] * (op.template getAttr<UnfoldAttr::KernelDims>()[0] - 1) + 1; @@ -44,19 +45,21 @@ void Aidge::Unfold_OpImpl<DIM>::forward() { static_cast<float>(op.template getAttr<UnfoldAttr::StrideDims>()[1]))); const DimSize_t outChannels = op.getOutput(0)->dims()[1]; - for (DimSize_t outC = 0; outC < outChannels; ++outC) { - const auto inOffsetH = outC % kernelDims[0]; - const auto inOffsetW = (outC / kernelDims[0]) % kernelDims[1]; - const auto inC = outC / kernelDims[0] / kernelDims[1]; + for (DimSize_t n = 0; n < op.getOutput(0)->dims()[0]; ++n) { + for (DimSize_t outC = 0; outC < outChannels; ++outC) { + const auto inOffsetH = outC % kernelDims[1]; + const auto inOffsetW = (outC / kernelDims[1]) % kernelDims[0]; + const auto inC = outC / kernelDims[0] / kernelDims[1]; - for (DimSize_t outH = 0; outH < outHeight; ++outH) { - const auto inH = outH * strideDims[1] + inOffsetH * dilationDims[1]; + for (DimSize_t outH = 0; outH < outHeight; ++outH) { + const auto inH = outH * strideDims[0] + inOffsetH * dilationDims[0]; - for (DimSize_t outW = 0; outW < outWidth; ++outW) { - const auto inW = outW * strideDims[0] + inOffsetW * dilationDims[0]; + for (DimSize_t outW = 0; outW < outWidth; ++outW) { + const auto inW = outW * strideDims[1] + inOffsetW * dilationDims[1]; - op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr((inC * inHeight + inH) * inWidth + inW), 1, - (outC * outHeight + outH) * outWidth + outW); + op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(((n * inChannels + inC) * inHeight + inH) * inWidth + inW), 1, + ((n * outChannels + outC) * outHeight + outH) * outWidth + outW); + } } } }