diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index ceb058dbd30a2cf882261d35a1a3a64c6eacb76a..b5e37f9bc52d4a74dabf68022235feadc384748f 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -32,15 +32,13 @@ private: // FIXME: change accessibility std::vector<std::shared_ptr<Tensor>> mInputs; const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); - const IOIndex_t mNbInputs; public: static constexpr const char* Type = "Add"; Add_Op(const IOIndex_t nbIn) : Operator(Type), - mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())), - mNbInputs(nbIn) + mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())) { assert(nbIn > 0 && "Add should have at least one input"); setDatatype(DataType::Float32); @@ -52,12 +50,11 @@ public: */ Add_Op(const Add_Op& op) : Operator(Type), - mInputs(op.mInputs), - mNbInputs(op.mNbInputs), + mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs())), mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor - assert(mNbInputs > 0 && "Add should have at least one input"); + assert(op.nbInputs() > 0 && "Add should have at least one input"); setDatatype(op.mOutput->dataType()); mImpl = op.mImpl ? Registrar<Add_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; } @@ -80,7 +77,7 @@ public: // } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -90,10 +87,10 @@ public: if (!mInputs[0]->empty()) { const auto expectedDims = mInputs[0]->dims(); std::size_t nonEmptyInputTensor = 1; - for (; nonEmptyInputTensor < mNbInputs && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { + for (; nonEmptyInputTensor < nbInputs() && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { assert(expectedDims == mInputs[nonEmptyInputTensor]->dims()); } - if (nonEmptyInputTensor == mNbInputs) { + if (nonEmptyInputTensor == nbInputs()) { mOutput->resize(expectedDims); } } @@ -101,8 +98,8 @@ public: bool outputDimsForwarded() const override final { std::size_t forwarded = 0; - for (; forwarded < mNbInputs && (!mInputs[forwarded]->empty()); ++forwarded) {} - return ((forwarded==mNbInputs) && !(mOutput->empty())); + for (; forwarded < nbInputs() && (!mInputs[forwarded]->empty()); ++forwarded) {} + return ((forwarded==nbInputs()) && !(mOutput->empty())); } // void checkDims() const override final { @@ -112,13 +109,13 @@ public: // } // } inline Tensor& input(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); return *(mInputs[inputIdx].get()); } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); return mInputs[inputIdx]; } inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { @@ -128,7 +125,7 @@ public: } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < mNbInputs && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); return std::static_pointer_cast<Data>(mInputs[inputIdx]); } std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { @@ -143,7 +140,7 @@ public: mOutput->setBackend(name); // FIXME: temporary workaround - for (std::size_t i = 0; i < mNbInputs; ++i) { + for (std::size_t i = 0; i < nbInputs(); ++i) { mInputs[i]->setBackend(name); } } @@ -152,13 +149,13 @@ public: mOutput->setDatatype(datatype); // FIXME: temporary workaround - for (std::size_t i = 0; i < mNbInputs; ++i) { + for (std::size_t i = 0; i < nbInputs(); ++i) { mInputs[i]->setDatatype(datatype); } } - inline IOIndex_t nbInputs() const noexcept override final { return mNbInputs; } - inline IOIndex_t nbDataInputs() const noexcept override final { return mNbInputs; } + inline IOIndex_t nbInputs() const noexcept override final { return mInputs.size(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return mInputs.size(); } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } static const std::vector<std::string> getInputsName(){ diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 5bf113f125b4690fd9f1d60ef4814f21a5d89547..31b99370d12a16020ea7c6f9d35c08d9f616f10f 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -26,11 +26,11 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class ConcatAttr { NbInputs, Axis }; +enum class ConcatAttr { Axis }; class Concat_Op : public Operator, public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>, - public StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t> { + public StaticAttributes<ConcatAttr, DimSize_t> { private: // FIXME: change accessibility std::vector<std::shared_ptr<Tensor>> mInputs; @@ -39,15 +39,14 @@ private: public: static constexpr const char* Type = "Concat"; - using Attributes_ = StaticAttributes<ConcatAttr, IOIndex_t, DimSize_t>; + using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>; template <ConcatAttr e> using attr = typename Attributes_::template attr<e>; Concat_Op(const IOIndex_t nbIn, const DimSize_t axis) : Operator(Type), mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())), - Attributes_(attr<ConcatAttr::NbInputs>(nbIn), - attr<ConcatAttr::Axis>(axis)) + Attributes_(attr<ConcatAttr::Axis>(axis)) { assert(nbIn > 0 && "Concat should have at least one input"); setDatatype(DataType::Float32); @@ -60,11 +59,11 @@ public: Concat_Op(const Concat_Op& op) : Operator(Type), Attributes_(op), - mInputs(std::vector<std::shared_ptr<Tensor>>(op.getAttr<ConcatAttr::NbInputs>())), + mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs(), std::make_shared<Tensor>())), mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor - assert(op.getAttr<ConcatAttr::NbInputs>() > 0 && "Concat should have at least one input"); + assert(op.nbInputs() > 0 && "Concat should have at least one input"); setDatatype(op.mOutput->dataType()); mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; } @@ -87,7 +86,7 @@ public: // } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -103,7 +102,7 @@ public: if (computable) { auto outputDims = mInputs[0]->dims(); - for (std::size_t i = 1; i < getAttr<ConcatAttr::NbInputs>(); ++i) { + for (std::size_t i = 1; i < nbInputs(); ++i) { outputDims[getAttr<ConcatAttr::Axis>()] += mInputs[i]->dims()[getAttr<ConcatAttr::Axis>()]; } mOutput->resize(outputDims); @@ -121,13 +120,13 @@ public: // } // } inline Tensor& input(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); return *(mInputs[inputIdx].get()); } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); return mInputs[inputIdx]; } inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { @@ -137,7 +136,7 @@ public: } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < getAttr<ConcatAttr::NbInputs>() && "wrong inputIdx for Concat operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); return std::static_pointer_cast<Data>(mInputs[inputIdx]); } std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { @@ -152,7 +151,7 @@ public: mOutput->setBackend(name); // FIXME: temporary workaround - for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) { + for (std::size_t i = 0; i < nbInputs(); ++i) { mInputs[i]->setBackend(name); } } @@ -161,13 +160,13 @@ public: mOutput->setDatatype(datatype); // FIXME: temporary workaround - for (std::size_t i = 0; i < getAttr<ConcatAttr::NbInputs>(); ++i) { + for (std::size_t i = 0; i < nbInputs(); ++i) { mInputs[i]->setDatatype(datatype); } } - inline IOIndex_t nbInputs() const noexcept override final { return getAttr<ConcatAttr::NbInputs>(); } - inline IOIndex_t nbDataInputs() const noexcept override final { return getAttr<ConcatAttr::NbInputs>(); } + inline IOIndex_t nbInputs() const noexcept override final { return mInputs.size(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return mInputs.size(); } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } static const std::vector<std::string> getInputsName(){ @@ -186,7 +185,6 @@ inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis = namespace { template <> const char* const EnumStrings<Aidge::ConcatAttr>::data[] = { - "NbInputs", "Axis" }; }