From 8cce26f50e4479c1c24bcd08dfa2350e02e0020b Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Thu, 23 May 2024 16:43:50 +0200
Subject: [PATCH] fix Slice when step < 0

---
 include/aidge/operator/Slice.hpp | 22 ----------------------
 src/operator/Slice.cpp           | 15 +++++++++++----
 2 files changed, 11 insertions(+), 26 deletions(-)

diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp
index f3afa67a7..57a6aa2ea 100644
--- a/include/aidge/operator/Slice.hpp
+++ b/include/aidge/operator/Slice.hpp
@@ -24,11 +24,6 @@
 #include "aidge/utils/Types.h"
 
 namespace Aidge {
-// class Slice_OpImpl : public OperatorImpl {
-// public:
-//     Slice_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
-//     void forward() override;
-// };
 
 enum class SliceAttr { Starts, Ends, Axes, Steps };
 
@@ -109,21 +104,4 @@ template <>
 const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" };
 }
 
-// namespace Aidge {
-//     class SliceImplForward
-//     : public Registrable<SliceImplForward,
-//                          std::tuple<DataType>,
-//                          void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
-//     template <typename I>
-//     void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_);
-
-// namespace {
-// static Registrar<SliceImplForward> registrarSliceImplForward_Float32(
-//         {DataType::Float32}, Slice_forward_kernel<float>);
-// static Registrar<SliceImplForward> registrarSliceImplForward_Int32(
-//         {DataType::Int32}, Slice_forward_kernel<int>);
-// static Registrar<SliceImplForward> registrarSliceImplForward_Int64(
-//         {DataType::Float64}, Slice_forward_kernel<double>);
-// }
-// }
 #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp
index 0a486e37a..070ea8c19 100644
--- a/src/operator/Slice.cpp
+++ b/src/operator/Slice.cpp
@@ -11,7 +11,6 @@
 
 #include "aidge/operator/Slice.hpp"
 
-#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -149,11 +148,19 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
             DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
                             static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
                             static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
-
-            if(this->template getAttr<SliceAttr::Steps>()[i] == 0) {
+            std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i];
+            if(step == 0) {
                 AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
             }
-            const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i]));
+            if(step * (end - start) < 0) {
+                if(step < 0) {
+                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type());
+                }
+                else {
+                    AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type());
+                }
+            }
+            const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step)));
             // Check if slice length is valid
             if (sliceLength > getInput(0)->dims()[axis])
             {
-- 
GitLab