Skip to content
Snippets Groups Projects
Commit 697c8533 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Do not resize output in forward() as it is done in forwardDims()

parent 42db13e1
No related branches found
No related tags found
No related merge requests found
......@@ -36,33 +36,24 @@ void Aidge::Slice_OpImpl::forward() {
(op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
"start, end and axes arguments should be the same size.");
const std::size_t nbDims = op.getInput(0)->nbDims();
const std::vector<std::size_t>& inputDims = op.getInput(0)->dims();
auto outputDims = op.getInput(0)->dims();
const auto nbDims = op.getInput(0)->nbDims();
const auto& inputDims = op.getInput(0)->dims();
const auto& outputDims = op.getOutput(0)->dims();
// compute index of the output's first element
// compute output dimension at the same time (may change between two forward calls)
std::size_t beginning = 0;
const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
const DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size()));
DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
const DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ?
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis]));
DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis]));
const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>());
beginning += start * stridePostAxis;
const std::size_t sliceLength = end - start;
outputDims[axis] = sliceLength;
}
op.getOutput(0)->resize(outputDims);
// for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3}
std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
......@@ -195,16 +186,16 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) {
AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute");
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment