From 19f7875c4c486badf9aa9194eff48a27d44b6199 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 21 Nov 2023 15:25:28 +0100 Subject: [PATCH] fix Slice outputDims --- include/aidge/operator/Slice.hpp | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index ccdb66a1e..4a99045e2 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -67,7 +67,7 @@ public: } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(inputIdx < 4 && "operator Slice supports only 4 inputs"); + assert(inputIdx < 4 && "Slice operator supports only 4 inputs"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } @@ -75,27 +75,15 @@ public: void computeOutputDims() override final { if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty()) { + DimSize_t nbAxes = mInputs[1]->dims()[0]; const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr()); const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr()); const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr()); - DimSize_t nbAxes = mInputs[1]->dims()[0]; - std::vector<DimSize_t> outDims; - for(std::size_t i=0; i<mInputs[0]->dims().size();++i) + std::vector<DimSize_t> outDims = mInputs[0]->dims(); + for(std::size_t i=0; i<nbAxes;++i) { - - const int* idxPos = std::find(axes, axes + nbAxes, static_cast<int>(i)); - if(idxPos != (axes + nbAxes)) - { - // TODO make sure all indxes are positive before this - size_t idx = static_cast<size_t>(*idxPos); - int startVal = starts[idx]; - int endVal = ends[idx]; - outDims.push_back(endVal - startVal); - } - else - { - outDims.push_back(mInputs[0]->dims()[i]); - } + std::size_t axis = axes[i]>=0?axes[i]:axes[i]+mInputs[0]->nbDims(); + outDims[axis] = ends[i] - starts[i] + 1; } mOutput->resize(outDims); } @@ -114,22 +102,22 @@ public: inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert((inputIdx < 4) && "Slice Operator has 4 inputs"); + assert((inputIdx < 4) && "Slice operator has 4 inputs"); return mInputs[inputIdx]; } inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { - assert((outputIdx == 0) && "Slice Operator has only 1 output"); + assert((outputIdx == 0) && "Slice operator has only 1 output"); (void) outputIdx; // avoid unused warning return mOutput; } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(inputIdx < 4 && "operator supports only 4 inputs"); + assert(inputIdx < 4 && "Slice operator supports only 4 inputs"); return std::static_pointer_cast<Data>(mInputs[inputIdx]); } std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { - assert(outputIdx == 0 && "operator supports only 1 output"); + assert(outputIdx == 0 && "Slice operator supports only 1 output"); (void) outputIdx; // avoid unused warning return std::static_pointer_cast<Data>(mOutput); } -- GitLab