diff --git a/src/operator/AvgPooling.cpp b/src/operator/AvgPooling.cpp index 7561297bad7e38f56861776f81e38ee2260dbee8..966063cd05eb728fc6a976e9ae7949da36a9da5c 100644 --- a/src/operator/AvgPooling.cpp +++ b/src/operator/AvgPooling.cpp @@ -75,11 +75,12 @@ bool Aidge::AvgPooling_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { const auto strideDim = mAttributes->template getAttr<AvgPoolingAttr::StrideDims>()[dim]; const auto dilationDim = mAttributes->template getAttr<AvgPoolingAttr::Dilations>()[dim]; - outputDims[dim+2] = 1 + static_cast<DimSize_t>( - roundingFunction(static_cast<float>(inputDims[dim+2] - - (kernelDim - 1) * dilationDim - 1) / - static_cast<float>(strideDim))); + const float effective_size = static_cast<float>(inputDims[dim+2] - (kernelDim - 1) * dilationDim - 1); + const float rounded_val = roundingFunction(effective_size / static_cast<float>(strideDim)); + + outputDims[dim+2] = 1 + std::max(0, static_cast<int>(rounded_val)); } + outputDims[1] = inputDims[1]; outputDims[0] = inputDims[0]; getOutput(0)->resize(outputDims);