Skip to content
Snippets Groups Projects
Commit 8cce26f5 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix Slice when step < 0

parent cc38a58b
No related branches found
No related tags found
No related merge requests found
...@@ -24,11 +24,6 @@ ...@@ -24,11 +24,6 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
// class Slice_OpImpl : public OperatorImpl {
// public:
// Slice_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
// void forward() override;
// };
enum class SliceAttr { Starts, Ends, Axes, Steps }; enum class SliceAttr { Starts, Ends, Axes, Steps };
...@@ -109,21 +104,4 @@ template <> ...@@ -109,21 +104,4 @@ template <>
const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" }; const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" };
} }
// namespace Aidge {
// class SliceImplForward
// : public Registrable<SliceImplForward,
// std::tuple<DataType>,
// void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// template <typename I>
// void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_);
// namespace {
// static Registrar<SliceImplForward> registrarSliceImplForward_Float32(
// {DataType::Float32}, Slice_forward_kernel<float>);
// static Registrar<SliceImplForward> registrarSliceImplForward_Int32(
// {DataType::Int32}, Slice_forward_kernel<int>);
// static Registrar<SliceImplForward> registrarSliceImplForward_Int64(
// {DataType::Float64}, Slice_forward_kernel<double>);
// }
// }
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include "aidge/operator/Slice.hpp" #include "aidge/operator/Slice.hpp"
#include <algorithm>
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
...@@ -149,11 +148,19 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -149,11 +148,19 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i];
if(this->template getAttr<SliceAttr::Steps>()[i] == 0) { if(step == 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type()); AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
} }
const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i])); if(step * (end - start) < 0) {
if(step < 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type());
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type());
}
}
const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step)));
// Check if slice length is valid // Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis]) if (sliceLength > getInput(0)->dims()[axis])
{ {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment