diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 7bdbd8099ab79c9f1714989dc41cfc0893427bc9..34fabe3baffc893ed9ddf9181c164ffda24a126f 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -26,21 +26,20 @@ namespace Aidge { enum class SliceAttr { Beginning, SliceDims }; -template <DimIdx_t DIM> class Slice_Op : public OperatorTensor, - public Registrable<Slice_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op<DIM> &)>, - public StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>> { + public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, + public StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>> { public: static constexpr const char *Type = "Slice"; Slice_Op() = delete; - using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>>; + using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; - Slice_Op(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims) + Slice_Op(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<SliceAttr::Beginning>(beginningPos), attr<SliceAttr::SliceDims>(sliceDims)) @@ -55,7 +54,7 @@ public: : OperatorTensor(op), Attributes_(op) { - mImpl = op.mImpl ? Registrar<Slice_Op<DIM>>::create(mOutputs[0]->getImpl()->backend())(*this) + mImpl = op.mImpl ? Registrar<Slice_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -70,12 +69,8 @@ public: if (!getInput(0) || (getInput(0)->empty())) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); } - // Check input dimensions is compatible with slice dimensions - if (getInput(0)->nbDims() != DIM) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Error: input and slice dimensions are not the same size."); - } - std::array<DimSize_t, DIM> outputDims; - const std::array<DimSize_t, DIM> inputDims = getInput(0)->template dims<DIM>(); + std::vector<DimSize_t> outputDims = std::vector<DimSize_t>(getInput(0)->nbDims()); + const std::vector<DimSize_t> inputDims = getInput(0)->dims(); // Check that the sliced Tensor is actually part of the input Tensor // For a 5*5 tensor ('x') and a 3*3 slice kernel ('o'): @@ -85,7 +80,7 @@ public: // xxooo xxxoo // xxooo xxxoo std::vector<std::size_t> beginningCoords = mInputs[0]->getCoord(this->template getAttr<SliceAttr::Beginning>()); - for (std::size_t i = 0; i < DIM; ++i) { + for (std::size_t i = 0; i < getInput(0)->nbDims(); ++i) { if (beginningCoords[i] + this->template getAttr<SliceAttr::SliceDims>()[i] > inputDims[i]) { AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); } else { @@ -111,16 +106,11 @@ public: } }; -template <std::size_t DIM> -inline std::shared_ptr<Node> Slice(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims, + +inline std::shared_ptr<Node> Slice(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases - return std::make_shared<Node>(std::make_shared<Slice_Op<DIM>>( beginningPos, sliceDims), name); -} - -template <DimIdx_t DIM> -inline std::shared_ptr<Node> Slice(std::size_t beginningPos, DimSize_t const (&sliceDims)[DIM], const std::string& name = "") { - return Slice(beginningPos, to_array(sliceDims), name); + return std::make_shared<Node>(std::make_shared<Slice_Op>(beginningPos, sliceDims), name); } } // namespace Aidge