From 1520a9a083afd2b5bf2d33025d4e16d4cbbce53a Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 10 Nov 2023 09:31:02 +0000 Subject: [PATCH] Propagate the new class hierarchy consequences to Conv and AvgPooling --- include/aidge/operator/AvgPooling.hpp | 79 ++++----------------------- include/aidge/operator/Conv.hpp | 69 +++-------------------- src/backend/OperatorImpl.cpp | 12 ++-- 3 files changed, 25 insertions(+), 135 deletions(-) diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index dfcd0d5b3..490782331 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -19,7 +19,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" -#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" @@ -29,15 +29,11 @@ namespace Aidge { enum class AvgPoolingAttr { StrideDims, KernelDims }; template <DimIdx_t DIM> -class AvgPooling_Op : public Operator, +class AvgPooling_Op : public OperatorTensor, public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>, public StaticAttributes<AvgPoolingAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>> { -private: - // FIXME: change accessibility - std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); - const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: static constexpr const char *Type = "AvgPooling"; @@ -52,10 +48,10 @@ public: constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) - : Operator(Type), + : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<AvgPoolingAttr::StrideDims>(stride_dims), attr<AvgPoolingAttr::KernelDims>(kernel_dims)) { - setDatatype(DataType::Float32); + 
setDataType(DataType::Float32); } /** @@ -63,12 +59,12 @@ public: * @param op Operator to copy. */ AvgPooling_Op(const AvgPooling_Op<DIM>& op) - : Operator(Type), + : OperatorTensor(Type, 1, 0, 1), Attributes_(op), mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor - setDatatype(op.mOutput->dataType()); + setDataType(op.mOutput->dataType()); mImpl = op.mImpl ? Registrar<AvgPooling_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; } @@ -80,64 +76,23 @@ public: return std::make_shared<AvgPooling_Op<DIM>>(*this); } - void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(inputIdx < 1 && "operators supports only 3 inputs"); - (void) inputIdx; // avoid unused warning - assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); - - mInput = std::dynamic_pointer_cast<Tensor>(data); - } void computeOutputDims() override final { - if (!mInput->empty()) { + if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) { outputDims[dim+2] = 1 + static_cast<DimSize_t>( - std::floor(static_cast<float>(mInput->dims()[dim+2] - + std::floor(static_cast<float>(mInputs[0]->dims()[dim+2] - this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) / static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim]))); } - outputDims[1] = mInput->dims()[1]; - outputDims[0] = mInput->dims()[0]; - mOutput->resize(outputDims); + outputDims[1] = mInputs[0]->dims()[1]; + outputDims[0] = mInputs[0]->dims()[0]; + mOutputs[0]->resize(outputDims); } } - bool outputDimsForwarded() const override final { return !(mOutput->empty()); } - - - inline Tensor& input(const IOIndex_t inputIdx) const override final { - assert(inputIdx == 0 && "operators supports only 1 inputs"); - (void) inputIdx; // avoid unused warning - return *(mInput.get()); - } - inline Tensor& 
output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } - - - inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert(inputIdx == 0 && "AvgPooling Operators supports only 1 inputs"); - (void) inputIdx; // avoid unused warning - return mInput; - } - inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { - assert(outputIdx == 0 && "AvgPooling Operators has only 1 outputs"); - (void) outputIdx; // avoid unused warning - return mOutput; - } - - - std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(inputIdx == 0 && "operators supports only 1 inputs"); - (void) inputIdx; // avoid unused warning - return std::static_pointer_cast<Data>(mInput); - } - std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { - assert(outputIdx == 0 && "operator supports only 1 output"); - (void) outputIdx; // avoid unused warning - return std::static_pointer_cast<Data>(mOutput); - } - void setBackend(const std::string &name) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); @@ -147,16 +102,6 @@ public: mInput->setBackend(name); } - void setDatatype(const DataType &datatype) override { - mOutput->setDatatype(datatype); - - // FIXME: temporary workaround - mInput->setDatatype(datatype); - } - - inline IOIndex_t nbInputs() const noexcept override final { return 1; } - inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } - inline IOIndex_t nbOutputs() const noexcept override final { return 1; } static const std::vector<std::string> getInputsName(){ return {"data_input"}; } @@ -190,4 +135,4 @@ const char *const EnumStrings<Aidge::AvgPoolingAttr>::data[] = {"StrideDims", "KernelDims"}; } -#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */ +#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */ \ No newline at end of file diff --git a/include/aidge/operator/Conv.hpp 
b/include/aidge/operator/Conv.hpp index b1e3e34b0..62f2446f3 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -19,7 +19,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/Node.hpp" -#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Registrar.hpp" @@ -29,17 +29,12 @@ namespace Aidge { enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims }; template <DimIdx_t DIM> -class Conv_Op : public Operator, +class Conv_Op : public OperatorTensor, public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, DimSize_t, std::array<DimSize_t, DIM>> { -public: - // FIXME: change accessibility - std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), - std::make_shared<Tensor>()}; - const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); - public: +public: static constexpr const char *Type = "Conv"; Conv_Op() = delete; @@ -54,13 +49,13 @@ public: const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) - : Operator(Type), + : OperatorTensor(Type, 1, 2, 1), Attributes_(attr<ConvAttr::StrideDims>(stride_dims), attr<ConvAttr::DilationDims>(dilation_dims), attr<ConvAttr::InChannels>(in_channels), attr<ConvAttr::OutChannels>(out_channels), attr<ConvAttr::KernelDims>(kernel_dims)) { - setDatatype(DataType::Float32); + setDataType(DataType::Float32); } /** @@ -68,12 +63,12 @@ public: * @param op Operator to copy. 
*/ Conv_Op(const Conv_Op<DIM>& op) - : Operator(Type), + : OperatorTensor(Type, 1, 2, 1), Attributes_(op), mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor - setDatatype(op.mOutput->dataType()); + setDataType(op.mOutput->dataType()); mImpl = op.mImpl ? Registrar<Conv_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; } @@ -98,13 +93,6 @@ public: // } - void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(inputIdx < 3 && "operators supports only 3 inputs"); - assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); - - mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); - } - void computeOutputDims() override final { if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; @@ -125,37 +113,6 @@ public: } } - bool outputDimsForwarded() const override final { return !(mOutput->empty()); } - - - inline Tensor& input(const IOIndex_t inputIdx) const override final { - assert(inputIdx < 3 && "operators supports only 3 inputs"); - return *(mInputs[inputIdx].get()); } - inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } - - - inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert(inputIdx < 3 && "Conv Operators supports only 3 inputs"); - return mInputs[inputIdx]; - } - inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { - assert((outputIdx == 0) && "Conv Operator has only 1 output"); - (void) outputIdx; // avoid unused warning - return mOutput; - } - - - std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(inputIdx < 3 && "operators supports only 3 inputs"); - return std::static_pointer_cast<Data>(mInputs[inputIdx]); - } - std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { - assert(outputIdx == 0 && "operator supports only 1 output"); - 
(void) outputIdx; // avoid unused warning - return std::static_pointer_cast<Data>(mOutput); - } - - void setBackend(const std::string &name) override { mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -165,18 +122,6 @@ public: mInputs[2]->setBackend(name); } - void setDatatype(const DataType &datatype) override { - mOutput->setDatatype(datatype); - - // FIXME: temporary workaround - mInputs[0]->setDatatype(datatype); - mInputs[1]->setDatatype(datatype); - mInputs[2]->setDatatype(datatype); - } - - inline IOIndex_t nbInputs() const noexcept override final { return 3; } - inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } - inline IOIndex_t nbOutputs() const noexcept override final { return 1; } static const std::vector<std::string> getInputsName(){ return {"data_input", "weight", "bias"}; } diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 166754cc9..b76bf3336 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -25,25 +25,25 @@ Aidge::OperatorImpl::OperatorImpl(const Operator& op): } Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { - assert(mOp.getInput(inputIdx) && "requires valid input"); + assert(mOp.getRawInput(inputIdx) && "requires valid input"); // Requires the whole tensor by default - return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size(); + return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size(); } Aidge::NbElts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { - assert(mOp.getInput(inputIdx) && "requires valid input"); + assert(mOp.getRawInput(inputIdx) && "requires valid input"); // Protect the whole tensor by default - return std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->size(); + return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size(); } Aidge::NbElts_t Aidge::OperatorImpl::getRequiredMemory(const 
Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { - assert(mOp.getOutput(outputIdx) && "requires valid output"); + assert(mOp.getRawOutput(outputIdx) && "requires valid output"); // Requires the whole tensor by default, regardless of available data on inputs - return std::static_pointer_cast<Tensor>(mOp.getOutput(outputIdx))->size(); + return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size(); } Aidge::NbElts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { -- GitLab