From 117af588e5d856f687405c1ad4652eac78c3846d Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 13 Dec 2023 16:47:07 +0000 Subject: [PATCH] Cherry-pick Slice_Op changes from vit_operator branch --- include/aidge/operator/Slice.hpp | 10 ++++++---- src/operator/Slice.cpp | 24 +++++++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 5968fdeb4..26abaf291 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -29,17 +29,17 @@ enum class SliceAttr { Beginning, SliceDims }; class Slice_Op : public OperatorTensor, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, - public StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>> { + public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> { public: static const std::string Type; Slice_Op() = delete; - using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>>; + using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; - Slice_Op(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims) + Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<SliceAttr::Beginning>(beginningPos), attr<SliceAttr::SliceDims>(sliceDims)) @@ -107,7 +107,9 @@ public: }; -inline std::shared_ptr<Node> Slice(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims, +inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts, + const std::vector<std::int32_t> ends, + const std::vector<std::int32_t> axes, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Slice_Op>(beginningPos, sliceDims), name); diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index a25e290df..9ac6bd2c0 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -12,5 +12,27 @@ #include <string> #include "aidge/operator/Slice.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::Slice_Op::Type = "Slice"; \ No newline at end of file +const std::string Aidge::Slice_Op::Type = "Slice"; + +void Aidge::Slice_Op::computeOutputDims() { + // check input have been associated + if (!getInput(0) || (getInput(0)->empty())) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); + } + + DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); + std::vector<DimSize_t> outDims = getInput(0)->dims(); + for(std::size_t i=0; i<nbAxes;++i) + { + // For each slice operation get the params and cast them to size_t + std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; + std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; + std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; + std::size_t axis = axis_>=0?axis_:axis_+getInput(0)->nbDims(); + std::size_t start = start_>=0?start_:start_+getInput(0)->dims()[axis]; + std::size_t end = end_>=0?end_:end_+getInput(0)->dims()[axis]; + } +} -- GitLab