Skip to content
Snippets Groups Projects
Commit 8cce26f5 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix Slice when step < 0

parent cc38a58b
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!122Add missing attributes to operators
......@@ -24,11 +24,6 @@
#include "aidge/utils/Types.h"
namespace Aidge {
// class Slice_OpImpl : public OperatorImpl {
// public:
// Slice_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
// void forward() override;
// };
enum class SliceAttr { Starts, Ends, Axes, Steps };
......@@ -109,21 +104,4 @@ template <>
const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" };
}
// namespace Aidge {
// class SliceImplForward
// : public Registrable<SliceImplForward,
// std::tuple<DataType>,
// void(const Slice_Op::Attrs &, const std::vector<DimSize_t>&, const void *, void *)> {};
// template <typename I>
// void Slice_forward_kernel(const Slice_Op::Attrs &attrs, const std::vector<DimSize_t>&inputDims, const void *input_, void *output_);
// namespace {
// static Registrar<SliceImplForward> registrarSliceImplForward_Float32(
// {DataType::Float32}, Slice_forward_kernel<float>);
// static Registrar<SliceImplForward> registrarSliceImplForward_Int32(
// {DataType::Int32}, Slice_forward_kernel<int>);
// static Registrar<SliceImplForward> registrarSliceImplForward_Int64(
// {DataType::Float64}, Slice_forward_kernel<double>);
// }
// }
#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */
......@@ -11,7 +11,6 @@
#include "aidge/operator/Slice.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
......@@ -149,11 +148,19 @@ bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
if(this->template getAttr<SliceAttr::Steps>()[i] == 0) {
std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i];
if(step == 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step must be a non-zero value", type());
}
const std::size_t sliceLength = (end - start) / static_cast<DimSize_t>(std::abs(this->template getAttr<SliceAttr::Steps>()[i]));
if(step * (end - start) < 0) {
if(step < 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type());
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type());
}
}
const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step)));
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment