From 697c85330d9e4c85c1036bde8eb60c4dd3c88f5a Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 26 May 2024 17:00:06 +0200 Subject: [PATCH] Do not resize output in forward() as it is done in forwardDims() --- src/operator/Slice.cpp | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 8a9f5cbbf..e3ac4e774 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -36,33 +36,24 @@ void Aidge::Slice_OpImpl::forward() { (op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()), "start, end and axes arguments should be the same size."); - const std::size_t nbDims = op.getInput(0)->nbDims(); - - const std::vector<std::size_t>& inputDims = op.getInput(0)->dims(); - auto outputDims = op.getInput(0)->dims(); + const auto nbDims = op.getInput(0)->nbDims(); + const auto& inputDims = op.getInput(0)->dims(); + const auto& outputDims = op.getOutput(0)->dims(); // compute index of the output's first element - // compute output dimension at the same time (may change between two forward calls) std::size_t beginning = 0; const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size(); for (std::size_t i = 0; i < nbAxes; ++i) { // For each slice operation get the params and cast them to size_t - DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ? + const DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ? static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) : static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size())); - DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ? + const DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ? static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) : static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis])); - DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ? - static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) : - static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis])); const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); beginning += start * stridePostAxis; - const std::size_t sliceLength = end - start; - outputDims[axis] = sliceLength; } - op.getOutput(0)->resize(outputDims); - // for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3} std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims); @@ -195,16 +186,16 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute"); - DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); + const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); std::vector<DimSize_t> outDims = getInput(0)->dims(); for (std::size_t i = 0; i < nbAxes; ++i) { - DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ? + const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ? static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) : static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims())); - DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ? + const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ? static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) : static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? + const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); -- GitLab