Skip to content
Snippets Groups Projects
Commit 5503030c authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added missing methods

parent eeda15a5
No related branches found
No related tags found
No related merge requests found
...@@ -99,6 +99,8 @@ public: ...@@ -99,6 +99,8 @@ public:
} }
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override;
NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; NbElts_t getNbProducedData(IOIndex_t outputIdx) const override;
......
...@@ -121,14 +121,22 @@ public: ...@@ -121,14 +121,22 @@ public:
inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; } inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; }
/** /**
* @brief Minimum amount of data from a specific input for one computation pass. * @brief Minimum amount of data from a specific input required by the
* implementation to be run.
*
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return NbElts_t * @return NbElts_t
*/ */
virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
// Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/** /**
* @brief Amount of data from a specific input actually used in one computation pass. * @brief Total amount of consumed data from a specific input.
* *
* @param inputIdx Index of the input analysed. * @param inputIdx Index of the input analysed.
* @return NbElts_t * @return NbElts_t
...@@ -136,7 +144,7 @@ public: ...@@ -136,7 +144,7 @@ public:
virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/** /**
* @brief Amount of data ready to be used on a specific output. * @brief Total amount of produced data ready to be used on a specific output.
* *
* @param outputIdx Index of the output analysed. * @param outputIdx Index of the output analysed.
* @return NbElts_t * @return NbElts_t
......
...@@ -45,6 +45,36 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI ...@@ -45,6 +45,36 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI
} }
} }
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredProtected(inputIdx);
}
else {
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
if (inputOp.first) {
return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second);
}
else {
return 0;
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
if (mImpl) {
return mImpl->getRequiredMemory(outputIdx, inputsSize);
}
else {
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
if (outputOp.first) {
return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize);
}
else {
return 0;
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const {
if (mImpl) { if (mImpl) {
return mImpl->getNbConsumedData(inputIdx); return mImpl->getNbConsumedData(inputIdx);
......
...@@ -36,6 +36,16 @@ Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputI ...@@ -36,6 +36,16 @@ Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputI
return mImpl->getNbRequiredData(inputIdx); return mImpl->getNbRequiredData(inputIdx);
} }
Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type());
return mImpl->getNbRequiredProtected(inputIdx);
}
Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type());
return mImpl->getRequiredMemory(outputIdx, inputsSize);
}
Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const { Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type()); AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type());
return mImpl->getNbConsumedData(inputIdx); return mImpl->getNbConsumedData(inputIdx);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment