From 942303831a4410550d0f9c68e98ac613118533f9 Mon Sep 17 00:00:00 2001
From: Jerome Hue <jerome.hue@cea.fr>
Date: Mon, 25 Nov 2024 18:12:17 +0100
Subject: [PATCH] override updateConsumerProducer method in Stack operator

---
 include/aidge/operator/Stack.hpp |  1 +
 src/operator/Stack.cpp           | 10 +++++++++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/include/aidge/operator/Stack.hpp b/include/aidge/operator/Stack.hpp
index cf85bbe72..57df72c39 100644
--- a/include/aidge/operator/Stack.hpp
+++ b/include/aidge/operator/Stack.hpp
@@ -18,6 +18,7 @@ class StackProdConso : public ProdConso {
     Elts_t getRequiredMemory(
         const IOIndex_t outputIdx,
         const std::vector<DimSize_t> &inputsSize) const override final;
+    void resetConsummerProducer() override final;
 };
 
 class StackOpImpl : public OperatorImpl {
diff --git a/src/operator/Stack.cpp b/src/operator/Stack.cpp
index bae984f09..752a8460c 100644
--- a/src/operator/Stack.cpp
+++ b/src/operator/Stack.cpp
@@ -24,7 +24,7 @@ namespace Aidge {
 
 // TODO: Check why getRequiredMemory is always called with empty vector as
 // inputSize
-Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory(
+Elts_t StackProdConso::getRequiredMemory(
     const Aidge::IOIndex_t inputIdx,
     const std::vector<DimSize_t> &inputsSize) const {
     assert(mOp.getRawInput(inputIdx) && "requires valid input");
@@ -35,6 +35,13 @@ Aidge::Elts_t Aidge::StackProdConso::getRequiredMemory(
     return Elts_t::DataElts(op.getInput(inputIdx)->size());
 }
 
+void StackProdConso::resetConsummerProducer() {
+    ProdConso::updateConsummerProducer();
+
+    const StackOp &op = dynamic_cast<const StackOp &>(mOp);
+    op.forwardStep() = 0;
+}
+
 const std::string StackOp::s_type = "Stack";
 
 void StackOpImpl::forward() {
@@ -87,6 +94,7 @@ bool Aidge::StackOp::forwardDims(bool /*allowDataDependency*/) {
     return false;
 }
 
+
 void StackOp::setBackend(const std::string &name, DeviceIdx_t device) {
     if (Registrar<StackOp>::exists({name})) {
         SET_IMPL_MACRO(StackOp, *this, name);
-- 
GitLab