From 5503030cea24d63d1ad08ce99e250228ade31a99 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 25 Feb 2024 09:05:58 +0100 Subject: [PATCH] Added missing methods --- include/aidge/operator/MetaOperator.hpp | 2 ++ include/aidge/operator/Operator.hpp | 14 +++++++++--- src/operator/MetaOperator.cpp | 30 +++++++++++++++++++++++++ src/operator/Operator.cpp | 10 +++++++++ 4 files changed, 53 insertions(+), 3 deletions(-) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 102f33a37..d172af49f 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -99,6 +99,8 @@ public: } NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; + NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override; NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 808450030..7cfe6b925 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -121,14 +121,22 @@ public: inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; } /** - * @brief Minimum amount of data from a specific input for one computation pass. + * @brief Minimum amount of data from a specific input required by the + * implementation to be run. + * * @param inputIdx Index of the input analysed. * @return NbElts_t */ virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + // Amount of input data that cannot be overwritten during the execution. + virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; + + // Memory required at an output for a given input size. + virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + /** - * @brief Amount of data from a specific input actually used in one computation pass. + * @brief Total amount of consumed data from a specific input. * * @param inputIdx Index of the input analysed. * @return NbElts_t @@ -136,7 +144,7 @@ public: virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** - * @brief Amount of data ready to be used on a specific output. + * @brief Total amount of produced data ready to be used on a specific output. * * @param outputIdx Index of the output analysed. * @return NbElts_t diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 7a3780036..883185021 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -45,6 +45,36 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI } } +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbRequiredProtected(inputIdx); + } + else { + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + if (inputOp.first) { + return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second); + } + else { + return 0; + } + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { + if (mImpl) { + return mImpl->getRequiredMemory(outputIdx, inputsSize); + } + else { + const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; + if (outputOp.first) { + return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize); + } + else { + return 0; + } + } +} + Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbConsumedData(inputIdx); diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 6959e044c..af6973362 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -36,6 +36,16 @@ Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputI return mImpl->getNbRequiredData(inputIdx); } +Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const { + AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type()); + return mImpl->getNbRequiredProtected(inputIdx); +} + +Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { + AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type()); + return mImpl->getRequiredMemory(outputIdx, inputsSize); +} + Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type()); return mImpl->getNbConsumedData(inputIdx); -- GitLab