From 942303831a4410550d0f9c68e98ac613118533f9 Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Mon, 25 Nov 2024 18:12:17 +0100 Subject: [PATCH] override updateConsumerProducer method in Stack operator --- include/aidge/operator/Stack.hpp | 1 + src/operator/Stack.cpp | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp index cf85bbe72..57df72c39 100644 --- a/include/aidge/operator/Stack.hpp +++ b/include/aidge/operator/Stack.hpp @@ -18,6 +18,7 @@ class StackProdConso : public ProdConso { Elts_t getRequiredMemory( const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override final; + void resetConsummerProducer() override final; }; class StackOpImpl : public OperatorImpl { diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp index bae984f09..752a8460c 100644 --- a/src/operator/Stack.cpp +++ b/src/operator/Stack.cpp @@ -24,7 +24,7 @@ namespace Aidge { // TODO: Check why getRequiredMemory is always called with empty vector as // inputSize -Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory( +Elts_t StackProdConso::getRequiredMemory( const Aidge::IOIndex_t inputIdx, const std::vector<DimSize_t> &inputsSize) const { assert(mOp.getRawInput(inputIdx) && "requires valid input"); @@ -35,6 +35,13 @@ Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory( return Elts_t::DataElts(op.getInput(inputIdx)->size()); } +void StackProdConso::resetConsummerProducer() { + ProdConso::updateConsummerProducer(); + + const StackOp &op = dynamic_cast<const StackOp &>(mOp); + op.forwardStep() = 0; +} + const std::string StackOp::s_type = "Stack"; void StackOpImpl::forward() { @@ -87,6 +94,7 @@ bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) { return false; } + void StackOp::setBackend(const std::string &name, DeviceIdx_t device) { if (Registrar<StackOp>::exists({name})) { SET_IMPL_MACRO(StackOp, *this, name); -- GitLab