Skip to content
Snippets Groups Projects
Commit 6d67e92e authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixes

parent 984eaa14
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!153Im2col
Pipeline #49588 passed
...@@ -31,6 +31,7 @@ void Aidge::Unfold_OpImpl<DIM>::forward() { ...@@ -31,6 +31,7 @@ void Aidge::Unfold_OpImpl<DIM>::forward() {
const auto strideDims = op.template getAttr<UnfoldAttr::StrideDims>(); const auto strideDims = op.template getAttr<UnfoldAttr::StrideDims>();
const DimSize_t inHeight = op.getInput(0)->dims()[2]; const DimSize_t inHeight = op.getInput(0)->dims()[2];
const DimSize_t inWidth = op.getInput(0)->dims()[3]; const DimSize_t inWidth = op.getInput(0)->dims()[3];
const DimSize_t inChannels = op.getInput(0)->dims()[1];
const DimSize_t kernelExtentHeight = op.template getAttr<UnfoldAttr::DilationDims>()[0] * const DimSize_t kernelExtentHeight = op.template getAttr<UnfoldAttr::DilationDims>()[0] *
(op.template getAttr<UnfoldAttr::KernelDims>()[0] - 1) + 1; (op.template getAttr<UnfoldAttr::KernelDims>()[0] - 1) + 1;
...@@ -44,19 +45,21 @@ void Aidge::Unfold_OpImpl<DIM>::forward() { ...@@ -44,19 +45,21 @@ void Aidge::Unfold_OpImpl<DIM>::forward() {
static_cast<float>(op.template getAttr<UnfoldAttr::StrideDims>()[1]))); static_cast<float>(op.template getAttr<UnfoldAttr::StrideDims>()[1])));
const DimSize_t outChannels = op.getOutput(0)->dims()[1]; const DimSize_t outChannels = op.getOutput(0)->dims()[1];
for (DimSize_t outC = 0; outC < outChannels; ++outC) { for (DimSize_t n = 0; n < op.getOutput(0)->dims()[0]; ++n) {
const auto inOffsetH = outC % kernelDims[0]; for (DimSize_t outC = 0; outC < outChannels; ++outC) {
const auto inOffsetW = (outC / kernelDims[0]) % kernelDims[1]; const auto inOffsetH = outC % kernelDims[1];
const auto inC = outC / kernelDims[0] / kernelDims[1]; const auto inOffsetW = (outC / kernelDims[1]) % kernelDims[0];
const auto inC = outC / kernelDims[0] / kernelDims[1];
for (DimSize_t outH = 0; outH < outHeight; ++outH) { for (DimSize_t outH = 0; outH < outHeight; ++outH) {
const auto inH = outH * strideDims[1] + inOffsetH * dilationDims[1]; const auto inH = outH * strideDims[0] + inOffsetH * dilationDims[0];
for (DimSize_t outW = 0; outW < outWidth; ++outW) { for (DimSize_t outW = 0; outW < outWidth; ++outW) {
const auto inW = outW * strideDims[0] + inOffsetW * dilationDims[0]; const auto inW = outW * strideDims[1] + inOffsetW * dilationDims[1];
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr((inC * inHeight + inH) * inWidth + inW), 1, op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(((n * inChannels + inC) * inHeight + inH) * inWidth + inW), 1,
(outC * outHeight + outH) * outWidth + outW); ((n * outChannels + outC) * outHeight + outH) * outWidth + outW);
}
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment