Skip to content
Snippets Groups Projects
Commit 19f7875c authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix Slice outputDims

parent e2f9f441
No related branches found
No related tags found
No related merge requests found
...@@ -67,7 +67,7 @@ public: ...@@ -67,7 +67,7 @@ public:
} }
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx < 4 && "operator Slice supports only 4 inputs"); assert(inputIdx < 4 && "Slice operator supports only 4 inputs");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
} }
...@@ -75,27 +75,15 @@ public: ...@@ -75,27 +75,15 @@ public:
void computeOutputDims() override final { void computeOutputDims() override final {
if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty()) if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty())
{ {
DimSize_t nbAxes = mInputs[1]->dims()[0];
const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr()); const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr());
const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr()); const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr());
const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr()); const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr());
DimSize_t nbAxes = mInputs[1]->dims()[0]; std::vector<DimSize_t> outDims = mInputs[0]->dims();
std::vector<DimSize_t> outDims; for(std::size_t i=0; i<nbAxes;++i)
for(std::size_t i=0; i<mInputs[0]->dims().size();++i)
{ {
std::size_t axis = axes[i]>=0?axes[i]:axes[i]+mInputs[0]->nbDims();
const int* idxPos = std::find(axes, axes + nbAxes, static_cast<int>(i)); outDims[axis] = ends[i] - starts[i] + 1;
if(idxPos != (axes + nbAxes))
{
// TODO make sure all indxes are positive before this
size_t idx = static_cast<size_t>(*idxPos);
int startVal = starts[idx];
int endVal = ends[idx];
outDims.push_back(endVal - startVal);
}
else
{
outDims.push_back(mInputs[0]->dims()[i]);
}
} }
mOutput->resize(outDims); mOutput->resize(outDims);
} }
...@@ -114,22 +102,22 @@ public: ...@@ -114,22 +102,22 @@ public:
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < 4) && "Slice Operator has 4 inputs"); assert((inputIdx < 4) && "Slice operator has 4 inputs");
return mInputs[inputIdx]; return mInputs[inputIdx];
} }
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Slice Operator has only 1 output"); assert((outputIdx == 0) && "Slice operator has only 1 output");
(void) outputIdx; // avoid unused warning (void) outputIdx; // avoid unused warning
return mOutput; return mOutput;
} }
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx < 4 && "operator supports only 4 inputs"); assert(inputIdx < 4 && "Slice operator supports only 4 inputs");
return std::static_pointer_cast<Data>(mInputs[inputIdx]); return std::static_pointer_cast<Data>(mInputs[inputIdx]);
} }
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output"); assert(outputIdx == 0 && "Slice operator supports only 1 output");
(void) outputIdx; // avoid unused warning (void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput); return std::static_pointer_cast<Data>(mOutput);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment