diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp index cf85bbe72124283f37acef68f488fde2185447c3..57df72c39b8d1e00dee0ff5fcb9d205883d6e388 100644 --- a/include/aidge/operator/Stack.hpp +++ b/include/aidge/operator/Stack.hpp @@ -18,6 +18,7 @@ class StackProdConso : public ProdConso { Elts_t getRequiredMemory( const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override final; + void resetConsummerProducer() override final; }; class StackOpImpl : public OperatorImpl { diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index bae984f09dbfdd51bfa298fb37b8f829d55a506c..752a8460c9143fb01fb8796631608c3e5192d6b5 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -24,7 +24,7 @@ namespace Aidge { // TODO: Check why getRequiredMemory is always called with empty vector as // inputSize -Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory( +Elts_t StackProdConso::getRequiredMemory( const Aidge::IOIndex_t inputIdx, const std::vector<DimSize_t> &inputsSize) const { assert(mOp.getRawInput(inputIdx) && "requires valid input"); @@ -35,6 +35,13 @@ Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory( return Elts_t::DataElts(op.getInput(inputIdx)->size()); } +void StackProdConso::resetConsummerProducer() { + ProdConso::updateConsummerProducer(); + + const StackOp &op = dynamic_cast<const StackOp &>(mOp); + op.forwardStep() = 0; +} + const std::string StackOp::s_type = "Stack"; void StackOpImpl::forward() { @@ -87,6 +94,7 @@ bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { return false; } + void StackOp::setBackend(const std::string &name, DeviceIdx_t device) { if (Registrar<StackOp>::exists({name})) { SET_IMPL_MACRO(StackOp, *this, name);