From 8efe542dc86ee5618dd2b3d0ee2477ca45e45047 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Wed, 13 Dec 2023 16:57:21 +0100 Subject: [PATCH] change Slice attr to int64 --- include/aidge/operator/Slice.hpp | 12 ++++++------ src/operator/Slice.cpp | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index de4207fed..a80bac8d9 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes }; class Slice_Op : public OperatorTensor, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, - public StaticAttributes<SliceAttr, std::vector<int>, std::vector<int>, std::vector<int>> { + public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> { public: static const std::string Type; Slice_Op() = delete; - using Attributes_ = StaticAttributes<SliceAttr, std::vector<int>, std::vector<int>, std::vector<int>>; + using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; - Slice_Op(const std::vector<int>& starts, const std::vector<int>& ends, const std::vector<int>& axes) + Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<SliceAttr::Starts>(starts), attr<SliceAttr::Ends>(ends), @@ -85,9 +85,9 @@ public: }; -inline std::shared_ptr<Node> Slice(const std::vector<int> starts, - const std::vector<int> ends, - const std::vector<int> axes, +inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts, + const std::vector<std::int32_t> ends, + const std::vector<std::int32_t> axes, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 63e1048b9..3849f2a17 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -32,9 +32,9 @@ void Aidge::Slice_Op::computeOutputDims() { for(std::size_t i=0; i<nbAxes;++i) { // For each slice operation get the params and cast them to size_t - int axis_ = this->template getAttr<SliceAttr::Axes>()[i]; - int start_ = this->template getAttr<SliceAttr::Starts>()[i]; - int end_ = this->template getAttr<SliceAttr::Ends>()[i]; + std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; + std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; + std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; std::size_t axis = axis_>=0?axis_:axis_+getInput(0)->nbDims(); std::size_t start = start_>=0?start_:start_+getInput(0)->dims()[axis]; std::size_t end = end_>=0?end_:end_+getInput(0)->dims()[axis]; -- GitLab