From a95cddb10fad1c2d785041d38a6142d73d1e3df6 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 1 Mar 2024 09:54:25 +0100
Subject: [PATCH] Fixed missing producer consumer system reset

---
 include/aidge/backend/OperatorImpl.hpp         |  6 ++++++
 include/aidge/operator/Operator.hpp            |  2 ++
 include/aidge/scheduler/Scheduler.hpp          |  6 +-----
 python_binding/backend/pybind_OperatorImpl.cpp | 10 ++++++++++
 src/backend/OperatorImpl.cpp                   |  5 +++++
 src/operator/Operator.cpp                      |  4 ++++
 src/scheduler/Scheduler.cpp                    | 14 +++++++++++++-
 7 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp
index 19f083750..8b5aba10d 100644
--- a/include/aidge/backend/OperatorImpl.hpp
+++ b/include/aidge/backend/OperatorImpl.hpp
@@ -63,6 +63,12 @@ public:
      */
     virtual void updateConsummerProducer();
 
+    /**
+     * @brief Reset the Consummer Producer system.
+     *
+     */
+    virtual void resetConsummerProducer();
+
     virtual ~OperatorImpl() = default;
 
 protected:
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 7cfe6b925..9a32a9b9d 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -153,6 +153,8 @@ public:
 
     virtual void updateConsummerProducer();
 
+    virtual void resetConsummerProducer();
+
     virtual void forward();
 
     virtual void backward();
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 06c3b6e9e..747785bf8 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -58,11 +58,7 @@ public:
     ~SequentialScheduler() = default;
 
     void generateScheduling(bool verbose = false);
-    inline void resetScheduling() {
-        mScheduling.clear();
-        mStaticSchedule.clear();
-        mStaticScheduleStep = 0;
-    }
+    void resetScheduling();
 
     /**
      * Generate the memory layout for the current static scheduling.
diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp
index 346100690..a2a5e6b8b 100644
--- a/python_binding/backend/pybind_OperatorImpl.cpp
+++ b/python_binding/backend/pybind_OperatorImpl.cpp
@@ -102,6 +102,15 @@ public:
 
         );
     }
+    void resetConsummerProducer() override {
+        PYBIND11_OVERRIDE_NAME(
+            void,
+            OperatorImpl,
+            "reset_consummer_producer",
+            resetConsummerProducer,
+
+        );
+    }
 };
 
 void init_OperatorImpl(py::module& m){
@@ -116,6 +125,7 @@ void init_OperatorImpl(py::module& m){
     .def("get_nb_consumed_data", &OperatorImpl::getNbConsumedData)
     .def("get_nb_produced_data", &OperatorImpl::getNbProducedData)
     .def("update_consummer_producer", &OperatorImpl::updateConsummerProducer)
+    .def("reset_consummer_producer", &OperatorImpl::resetConsummerProducer)
     ;
 }
 }
diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp
index b76bf3336..1911da228 100644
--- a/src/backend/OperatorImpl.cpp
+++ b/src/backend/OperatorImpl.cpp
@@ -68,6 +68,11 @@ void Aidge::OperatorImpl::updateConsummerProducer(){
     }
 }
 
+void Aidge::OperatorImpl::resetConsummerProducer(){
+    std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), 0);
+    std::fill(mNbProducedData.begin(), mNbProducedData.end(), 0);
+}
+
 void Aidge::OperatorImpl::forward() {
     AIDGE_THROW_OR_ABORT(std::runtime_error, "forward() not implemented");
 }
diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp
index af6973362..289b2be90 100644
--- a/src/operator/Operator.cpp
+++ b/src/operator/Operator.cpp
@@ -59,6 +59,10 @@ void Aidge::Operator::updateConsummerProducer(){
     AIDGE_ASSERT(mImpl != nullptr, "updateConsummerProducer(): an implementation is required for {}!", type());
     mImpl->updateConsummerProducer();
 }
+void Aidge::Operator::resetConsummerProducer(){
+    AIDGE_ASSERT(mImpl != nullptr, "resetConsummerProducer(): an implementation is required for {}!", type());
+    mImpl->resetConsummerProducer();
+}
 
 void Aidge::Operator::runHooks() const {
     for (auto& hook : mHooks) {
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index c82f58590..d45bc5f8e 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -301,6 +301,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
     }
 }
 
+void Aidge::SequentialScheduler::resetScheduling() {
+    for (auto node : mGraphView->getNodes()) {
+        node->getOperator()->resetConsummerProducer();
+    }
+
+    mStaticSchedule.clear();
+    mStaticScheduleStep = 0;
+    mScheduling.clear();
+}
+
 /**
  * This version is a simplified version without special handling of concatenation.
 */
@@ -423,7 +433,6 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge
 }
 
 
-// TODO: handle multiple inputs/outputs
 void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
     
     // Collect all data input of the graph (that are producers)
@@ -461,6 +470,9 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve
     fmt::print("\n");
 
     ++mStaticScheduleStep;
+    if (mStaticScheduleStep == mStaticSchedule.size()) {
+        mStaticScheduleStep = 0;
+    }
 }
 
 void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
-- 
GitLab