diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 5bb07ae01d8f076891a803698d2b3f489d90b462..bf98736f0cab95b4ad618d1bee0850520144428d 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -31,6 +31,10 @@ public: void forward() override; }; +// Implementation note: +// If start or end are out of bound then it takes the max value for the given axe. +// Example Slice with start=1, end=1000, axes=0 for tensor [0, 1, 2, 3] +// Will return [1, 2, 3] enum class SliceAttr { Starts, Ends, Axes, Steps }; class Slice_Op diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 02dcad58c47fb804f50b9eb2e20be45a12e73fae..ab2d5b264ab0605f4b414287381c059e1289ce68 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -253,12 +253,17 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { const DimIdx_t axis = this->axes()[i] >= 0 ? static_cast<DimIdx_t>(this->axes()[i]) : static_cast<DimIdx_t>(this->axes()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims())); - const DimSize_t start = this->starts()[i] >= 0 ? + DimSize_t start = this->starts()[i] >= 0 ? static_cast<DimSize_t>(this->starts()[i]) : static_cast<DimSize_t>(this->starts()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - const DimSize_t end = this->ends()[i] >= 0 ? + // Clamp start to the range [0, axis_dim] + start = std::max(static_cast<DimSize_t>(0), std::min(start, getInput(0)->dims()[axis]-1)); + + DimSize_t end = this->ends()[i] >= 0 ? static_cast<DimSize_t>(this->ends()[i]) : static_cast<DimSize_t>(this->ends()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); + // Clamp end to the range [0, axis_dim] + end = std::max(static_cast<DimSize_t>(0), std::min(end, getInput(0)->dims()[axis]-1)); const std::int64_t step = this->steps()[i]; AIDGE_ASSERT(step != 0, "Slice_Op: Step ({}) must have a non-zero value on axis {}!", this->steps(), axis); @@ -309,4 +314,4 @@ std::shared_ptr<Aidge::Node> Aidge::Slice(const std::vector<std::int64_t>& start const std::vector<std::int64_t>& steps, const std::string &name) { return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes, steps), name); -} \ No newline at end of file +}