From 95ba2718a641617dd904a7d9e87a0f4ae7848ad5 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 11 Oct 2023 12:13:07 +0000 Subject: [PATCH] [Add] Slice Operator --- include/aidge/operator/Slice.hpp | 185 +++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 include/aidge/operator/Slice.hpp diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp new file mode 100644 index 000000000..839d873a0 --- /dev/null +++ b/include/aidge/operator/Slice.hpp @@ -0,0 +1,185 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_SLICE_H_ +#define AIDGE_CORE_OPERATOR_SLICE_H_ + +#include <memory> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class SliceAttr { Beginning, SliceDims }; + +template <DimIdx_t DIM> +class Slice_Op + : public Operator, + public Registrable<Slice_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op<DIM> &)>, + public StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>> { +public: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char *Type = "Slice"; + + Slice_Op() = delete; + + using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>>; + template <SliceAttr e> + using attr = typename Attributes_::template attr<e>; + + Slice_Op(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims) + : Operator(Type), + Attributes_(attr<SliceAttr::Beginning>(beginningPos), + attr<SliceAttr::SliceDims>(sliceDims)) + { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its + * input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Slice_Op(const Slice_Op &op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Slice_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) + : nullptr; + } + +public: + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Slice_Op + */ + std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + (void)inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) { + // Check input dimensions is compatible with slice dimensions + if (mInput->nbDims() != DIM) { + printf("Error: input and slice dimensions are not the same size.\n"); + exit(-1); + } + std::array<DimSize_t, DIM> outputDims; + + // Check that the sliced Tensor is actually part of the input Tensor + // For a 5*5 tensor ('x') and a 3*3 slice kernel ('o'): + // xxxxx xxxxx + // xxxxx xxxxx + // xxooo --> ok xxxoo --> out of bound + // xxooo xxxoo + // xxooo xxxoo + std::vector<std::size_t> beginningCoords = mInput->getCoord(this->template getAttr<SliceAttr::Beginning>()); + for (std::size_t i = 0; i < DIM; ++i) { + if (beginningCoords[i] + this->template getAttr<SliceAttr::SliceDims>()[i] >= mInput->dims()[i]) { + printf("ROI of Slice operator out of bounds"); + exit(-1); + } else { + outputDims[i] = this->template getAttr<SliceAttr::SliceDims>()[i]; + } + } + + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + inline Tensor &input(const IOIndex_t /*inputIdx*/) const override final { + return *(mInput.get()); + } + inline Tensor &output(const IOIndex_t /*outputIdx*/) const override final { + return *(mOutput.get()); + } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx == 0) && "Slice Operator has only 1 input"); + (void)inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Slice Operator has only 1 output"); + (void)outputIdx; // avoid unused warning + return mOutput; + } + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + (void)inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void)outputIdx; // avoid unused warning + return mOutput; + } + + void setBackend(const std::string &name) { + mImpl = Registrar<Slice_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + void setDatatype(const DataType &datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } +}; + +template <DimIdx_t DIM> +inline std::shared_ptr<Node> Slice(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims, + const std::string &name = "") { + // FIXME: properly handle default w&b initialization in every cases + return std::make_shared<Node>(std::make_shared<Slice_Op<DIM>>( beginningPos, sliceDims), name); +} + +template <DimIdx_t DIM> +inline std::shared_ptr<Node> Slice(std::size_t beginningPos, DimSize_t const (&sliceDims)[DIM], const std::string& name = "") { + return Slice(beginningPos, to_array(sliceDims), name); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Beginning", "SliceDims" }; +} + +#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ -- GitLab