diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index f3afa67a7614b1e525012fca4f4ecead4546846a..57a6aa2eafede5c5d0e64819b16f6a186de38306 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -24,11 +24,6 @@ #include "aidge/utils/Types.h" namespace Aidge { -// class Slice_OpImpl : public OperatorImpl { -// public: -// Slice_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} -// void forward() override; -// }; enum class SliceAttr { Starts, Ends, Axes, Steps }; @@ -109,21 +104,4 @@ template <> const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" }; } -// namespace Aidge { -// class SliceImplForward -// : public Registrable<SliceImplForward, -// std::tuple<DataType>, -// void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {}; -// template <typename I> -// void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_); - -// namespace { -// static Registrar<SliceImplForward> registrarSliceImplForward_Float32( -// {DataType::Float32}, Slice_forward_kernel<float>); -// static Registrar<SliceImplForward> registrarSliceImplForward_Int32( -// {DataType::Int32}, Slice_forward_kernel<int>); -// static Registrar<SliceImplForward> registrarSliceImplForward_Int64( -// {DataType::Float64}, Slice_forward_kernel<double>); -// } -// } #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 0a486e37a36b02c0a8252036ee34ca805bae725c..070ea8c1991fa936a4f9421cad1b266f24fe87a2 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -11,7 +11,6 @@ #include "aidge/operator/Slice.hpp" -#include <algorithm> #include <cassert> #include <cstddef> #include <cstdint> @@ -149,11 +148,19 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) { DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - - if(this->template getAttr<SliceAttr::Steps>()[i] == 0) { + std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i]; + if(step == 0) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type()); } - const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i])); + if(step * (end - start) < 0) { + if(step < 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type()); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type()); + } + } + const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step))); // Check if slice length is valid if (sliceLength > getInput(0)->dims()[axis]) {