Skip to content
Snippets Groups Projects
Commit 19f7875c authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix Slice outputDims

parent e2f9f441
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
......@@ -67,7 +67,7 @@ public:
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx < 4 && "operator Slice supports only 4 inputs");
assert(inputIdx < 4 && "Slice operator supports only 4 inputs");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
......@@ -75,27 +75,15 @@ public:
void computeOutputDims() override final {
if (!mInputs[0]->empty() && !mInputs[1]->empty() && !mInputs[2]->empty()&& !mInputs[3]->empty())
{
DimSize_t nbAxes = mInputs[1]->dims()[0];
const int* axes = static_cast<const int*>(mInputs[1]->getImpl()->rawPtr());
const int* starts = static_cast<const int*>(mInputs[2]->getImpl()->rawPtr());
const int* ends = static_cast<const int*>(mInputs[3]->getImpl()->rawPtr());
DimSize_t nbAxes = mInputs[1]->dims()[0];
std::vector<DimSize_t> outDims;
for(std::size_t i=0; i<mInputs[0]->dims().size();++i)
std::vector<DimSize_t> outDims = mInputs[0]->dims();
for(std::size_t i=0; i<nbAxes;++i)
{
const int* idxPos = std::find(axes, axes + nbAxes, static_cast<int>(i));
if(idxPos != (axes + nbAxes))
{
// TODO make sure all indxes are positive before this
size_t idx = static_cast<size_t>(*idxPos);
int startVal = starts[idx];
int endVal = ends[idx];
outDims.push_back(endVal - startVal);
}
else
{
outDims.push_back(mInputs[0]->dims()[i]);
}
std::size_t axis = axes[i]>=0?axes[i]:axes[i]+mInputs[0]->nbDims();
outDims[axis] = ends[i] - starts[i] + 1;
}
mOutput->resize(outDims);
}
......@@ -114,22 +102,22 @@ public:
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < 4) && "Slice Operator has 4 inputs");
assert((inputIdx < 4) && "Slice operator has 4 inputs");
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Slice Operator has only 1 output");
assert((outputIdx == 0) && "Slice operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx < 4 && "operator supports only 4 inputs");
assert(inputIdx < 4 && "Slice operator supports only 4 inputs");
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "operator supports only 1 output");
assert(outputIdx == 0 && "Slice operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment