From 8cce26f50e4479c1c24bcd08dfa2350e02e0020b Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Thu, 23 May 2024 16:43:50 +0200 Subject: [PATCH] fix Slice when step < 0 --- include/aidge/operator/Slice.hpp | 22 ---------------------- src/operator/Slice.cpp | 15 +++++++++++---- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index f3afa67a7..57a6aa2ea 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -24,11 +24,6 @@ #include "aidge/utils/Types.h" namespace Aidge { -// class Slice_OpImpl : public OperatorImpl { -// public: -// Slice_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} -// void forward() override; -// }; enum class SliceAttr { Starts, Ends, Axes, Steps }; @@ -109,21 +104,4 @@ template <> const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" }; } -// namespace Aidge { -// class SliceImplForward -// : public Registrable<SliceImplForward, -// std::tuple<DataType>, -// void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {}; -// template <typename I> -// void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_); - -// namespace { -// static Registrar<SliceImplForward> registrarSliceImplForward_Float32( -// {DataType::Float32}, Slice_forward_kernel<float>); -// static Registrar<SliceImplForward> registrarSliceImplForward_Int32( -// {DataType::Int32}, Slice_forward_kernel<int>); -// static Registrar<SliceImplForward> registrarSliceImplForward_Int64( -// {DataType::Float64}, Slice_forward_kernel<double>); -// } -// } #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 0a486e37a..070ea8c19 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -11,7 +11,6 @@ #include "aidge/operator/Slice.hpp" -#include <algorithm> #include <cassert> #include <cstddef> #include <cstdint> @@ -149,11 +148,19 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) { DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - - if(this->template getAttr<SliceAttr::Steps>()[i] == 0) { + std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i]; + if(step == 0) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type()); } - const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i])); + if(step * (end - start) < 0) { + if(step < 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type()); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type()); + } + } + const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step))); // Check if slice length is valid if (sliceLength > getInput(0)->dims()[axis]) { -- GitLab