diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 102f33a3714aaf1ee95556514aff4d19d8034271..d172af49fb85e054c02ac7d2c1ea1f0855b1264a 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -99,6 +99,8 @@ public: } NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override; + NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override; NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 808450030bdfc176c9cbc435c76b4932586397b8..7cfe6b92521c3ef00528d1b5eff602d9f52b11fd 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -121,14 +121,22 @@ public: inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; } /** - * @brief Minimum amount of data from a specific input for one computation pass. + * @brief Minimum amount of data from a specific input required by the + * implementation to be run. + * * @param inputIdx Index of the input analysed. * @return NbElts_t */ virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + // Amount of input data that cannot be overwritten during the execution. + virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const; + + // Memory required at an output for a given input size. + virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const; + /** - * @brief Amount of data from a specific input actually used in one computation pass. + * @brief Total amount of consumed data from a specific input. * * @param inputIdx Index of the input analysed. * @return NbElts_t @@ -136,7 +144,7 @@ public: virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** - * @brief Amount of data ready to be used on a specific output. + * @brief Total amount of produced data ready to be used on a specific output. * * @param outputIdx Index of the output analysed. * @return NbElts_t diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 7a3780036f730d1f1d635a75f99d44a2f073d1bb..883185021b395b42e5c47ef0461ebc0614f14456 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -45,6 +45,36 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI } } +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbRequiredProtected(inputIdx); + } + else { + const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; + if (inputOp.first) { + return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second); + } + else { + return 0; + } + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { + if (mImpl) { + return mImpl->getRequiredMemory(outputIdx, inputsSize); + } + else { + const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; + if (outputOp.first) { + return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize); + } + else { + return 0; + } + } +} + Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbConsumedData(inputIdx); diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 6959e044c8d0b866ef00f99da87d7d701b817548..af697336284542edf38559f7b052e5211ddeb7d0 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -36,6 +36,16 @@ Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputI return mImpl->getNbRequiredData(inputIdx); } +Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const { + AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type()); + return mImpl->getNbRequiredProtected(inputIdx); +} + +Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const { + AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type()); + return mImpl->getRequiredMemory(outputIdx, inputsSize); +} + Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type()); return mImpl->getNbConsumedData(inputIdx);