diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp index 57df72c39b8d1e00dee0ff5fcb9d205883d6e388..9644620d71276c5e35fc9daaf634f4d4cdb28405 100644 --- a/include/aidge/operator/Stack.hpp +++ b/include/aidge/operator/Stack.hpp @@ -18,7 +18,7 @@ class StackProdConso : public ProdConso { Elts_t getRequiredMemory( const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override final; - void resetConsummerProducer() override final; + void resetConsummerProducer() override; }; class StackOpImpl : public OperatorImpl { diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index 752a8460c9143fb01fb8796631608c3e5192d6b5..efe6296a351f69ef3a11d4e1bc04bd0b52d46a06 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -32,11 +32,15 @@ Elts_t StackProdConso::getRequiredMemory( const StackOp &op = dynamic_cast<const StackOp &>(mOp); // The produced data after one forward pass is simply the input size, // we do not produce the whole output tensor everytime. - return Elts_t::DataElts(op.getInput(inputIdx)->size()); + if (op.forwardStep() <= op.maxElements()) { + return Elts_t::DataElts(op.getInput(inputIdx)->size()); + } else { + return Elts_t::NoneElts(); + } } void StackProdConso::resetConsummerProducer() { - ProdConso::updateConsummerProducer(); + ProdConso::resetConsummerProducer(); const StackOp &op = dynamic_cast<const StackOp &>(mOp); op.forwardStep() = 0; @@ -94,7 +98,6 @@ bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { return false; } - void StackOp::setBackend(const std::string &name, DeviceIdx_t device) { if (Registrar<StackOp>::exists({name})) { SET_IMPL_MACRO(StackOp, *this, name);