diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index de4207fedd7da627bdb18992c31bb923d9ac7782..a80bac8d9097a91088ead24af2651548e09a8b75 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes }; class Slice_Op : public OperatorTensor, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, - public StaticAttributes<SliceAttr, std::vector<int>, std::vector<int>, std::vector<int>> { + public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> { public: static const std::string Type; Slice_Op() = delete; - using Attributes_ = StaticAttributes<SliceAttr, std::vector<int>, std::vector<int>, std::vector<int>>; + using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; - Slice_Op(const std::vector<int>& starts, const std::vector<int>& ends, const std::vector<int>& axes) + Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<SliceAttr::Starts>(starts), attr<SliceAttr::Ends>(ends), @@ -85,9 +85,9 @@ public: }; -inline std::shared_ptr<Node> Slice(const std::vector<int> starts, - const std::vector<int> ends, - const std::vector<int> axes, +inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts, + const std::vector<std::int32_t> ends, + const std::vector<std::int32_t> axes, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 63e1048b942f05b43303c3f28465f79821d4ae01..3849f2a17edef1166a3a1ff56679785f354abb2c 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -32,9 +32,9 @@ void Aidge::Slice_Op::computeOutputDims() { for(std::size_t i=0; i<nbAxes;++i) { // For each slice operation get the params and cast them to size_t - int axis_ = this->template getAttr<SliceAttr::Axes>()[i]; - int start_ = this->template getAttr<SliceAttr::Starts>()[i]; - int end_ = this->template getAttr<SliceAttr::Ends>()[i]; + std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; + std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; + std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; std::size_t axis = axis_>=0?axis_:axis_+getInput(0)->nbDims(); std::size_t start = start_>=0?start_:start_+getInput(0)->dims()[axis]; std::size_t end = end_>=0?end_:end_+getInput(0)->dims()[axis];