Skip to content
Snippets Groups Projects
Commit 5503030c authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added missing methods

parent eeda15a5
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!77Support for recurrent networks
......@@ -99,6 +99,8 @@ public:
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override;
NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override;
NbElts_t getNbProducedData(IOIndex_t outputIdx) const override;
......
......@@ -121,14 +121,22 @@ public:
inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; }
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @brief Minimum amount of data from a specific input required by the
* implementation to be run.
*
* @param inputIdx Index of the input analysed.
* @return NbElts_t
*/
virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const;
// Amount of input data that cannot be overwritten during the execution.
virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
// Memory required at an output for a given input size.
virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;
/**
* @brief Amount of data from a specific input actually used in one computation pass.
* @brief Total amount of consumed data from a specific input.
*
* @param inputIdx Index of the input analysed.
* @return NbElts_t
......@@ -136,7 +144,7 @@ public:
virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const;
/**
* @brief Amount of data ready to be used on a specific output.
* @brief Total amount of produced data ready to be used on a specific output.
*
* @param outputIdx Index of the output analysed.
* @return NbElts_t
......
......@@ -45,6 +45,36 @@ Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputI
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredProtected(inputIdx);
}
else {
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
if (inputOp.first) {
return inputOp.first->getOperator()->getNbRequiredProtected(inputOp.second);
}
else {
return 0;
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
if (mImpl) {
return mImpl->getRequiredMemory(outputIdx, inputsSize);
}
else {
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
if (outputOp.first) {
return outputOp.first->getOperator()->getRequiredMemory(outputOp.second, inputsSize);
}
else {
return 0;
}
}
}
Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbConsumedData(inputIdx);
......
......@@ -36,6 +36,16 @@ Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputI
return mImpl->getNbRequiredData(inputIdx);
}
Aidge::NbElts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type());
return mImpl->getNbRequiredProtected(inputIdx);
}
Aidge::NbElts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type());
return mImpl->getRequiredMemory(outputIdx, inputsSize);
}
Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type());
return mImpl->getNbConsumedData(inputIdx);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment