From e7f5fa0147a03810a10dd38828df6eb9ac0f2cc2 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Sat, 19 Oct 2024 20:38:12 +0000
Subject: [PATCH] Fix Scheduler::StaticSchedulingElement shared_ptr circular
 reference

- Change shared_ptr to raw ptr. It is possible without issue here as each pointer is stored and owned by Scheduler::mStaticSchedule and deleted with it
- Change Scheduler::resetScheduling() and Scheduler::~Scheduler() to delete raw pointers properly
---
 include/aidge/scheduler/Scheduler.hpp | 14 +++++------
 src/scheduler/ParallelScheduler.cpp   |  8 +++---
 src/scheduler/Scheduler.cpp           | 35 ++++++++++++++++++---------
 src/scheduler/SequentialScheduler.cpp |  2 +-
 4 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 981920ea1..2d03f4e8b 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -61,8 +61,8 @@ protected:
         std::shared_ptr<Node> node; /** Scheduled `Node` */
         std::size_t early; /** Earliest possible execution time */
         std::size_t late; /** Latest possible execution time */
-        std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; /** Nodes that must be executed earlier */
-        std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; /** Nodes that must be executed later */
+        std::vector<StaticSchedulingElement*> earlierThan; /** Nodes that must be executed earlier */
+        std::vector<StaticSchedulingElement*> laterThan; /** Nodes that must be executed later */
     };
 
     /**
@@ -110,7 +110,7 @@ public:
         // ctor
     };
 
-    virtual ~Scheduler() noexcept;
+    virtual ~Scheduler();
 
 public:
     /**
@@ -192,9 +192,9 @@ protected:
      * @brief Generate an initial base scheduling for the GraphView.
      * The scheduling is entirely sequential and garanteed to be valid w.r.t.
      * each node producer-consumer model.
-     * @return Vector of shared pointers to `StaticSchedulingElement` representing the base schedule.
+     * @return Vector of pointers to `StaticSchedulingElement` representing the base schedule.
     */
-    std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const;
+    std::vector<StaticSchedulingElement*> generateBaseScheduling() const;
 
     /**
      * @brief Calculates early and late execution times for each node in an initial base scheduling.
@@ -207,7 +207,7 @@ protected:
      *
      * @param schedule Vector of shared pointers to StaticSchedulingElements to be processed
      */
-    void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const;
+    void generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const;
 
 private:
     /**
@@ -227,7 +227,7 @@ protected:
     /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
     std::vector<SchedulingElement> mScheduling;
     /** @brief List of nodes ordered by their */
-    std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
+    std::vector<std::vector<StaticSchedulingElement*>> mStaticSchedule;
     std::size_t mStaticScheduleStep = 0;
     mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
 };
diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp
index 1d70646b7..2b9a1f5b6 100644
--- a/src/scheduler/ParallelScheduler.cpp
+++ b/src/scheduler/ParallelScheduler.cpp
@@ -48,7 +48,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
 
     // Sort static scheduling, the order will be the prefered threads scheduling
     // order for non critical nodes
-    std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
+    std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
     std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
         [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
 
@@ -59,12 +59,12 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
 
     size_t latest = 0;
     std::mutex schedulingMutex;
-    std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;
+    std::map<StaticSchedulingElement*, std::atomic<bool>> finished;
 
     while (!staticSchedule.empty()) {
         Log::debug("Step {}", latest);
 
-        std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;
+        std::vector<StaticSchedulingElement*> mustFinish;
 
         // Run all nodes that must be run at this step: latest (critical nodes)
         for (size_t i = 0; i < staticSchedule.size(); ) {
@@ -188,7 +188,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
         }
 
         // Wait for all nodes that must finish at latest to be finished
-        // By scheduling construction, no other node can be started before all 
+        // By scheduling construction, no other node can be started before all
         // nodes at latest step are finished
         while (true) {
             bool ready = true;
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 958b25432..34aea5ffd 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -37,7 +37,14 @@
 #include "aidge/utils/Types.h"
 
 
-Aidge::Scheduler::~Scheduler() noexcept = default;
+Aidge::Scheduler::~Scheduler() {
+    for (auto& staticScheduleVec : mStaticSchedule) {
+        for (auto& staticScheduleElt : staticScheduleVec) {
+            delete staticScheduleElt;
+        }
+        staticScheduleVec.clear();
+    }
+}
 Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers() = default;
 Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers(const PriorProducersConsumers&) = default;
 Aidge::Scheduler::PriorProducersConsumers::~PriorProducersConsumers() noexcept = default;
@@ -48,7 +55,7 @@ void Aidge::Scheduler::generateScheduling() {
     mStaticSchedule.push_back(schedule);
 }
 
-std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const {
+std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
 
     // 0) setup useful variables
     // map associating each node with string "name (type#rank)"
@@ -60,7 +67,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
     // producers-consumers model!
     std::set<std::shared_ptr<Node>> stillConsumers;
 
-    std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
+    std::vector<StaticSchedulingElement*> schedule;
 
 
     // 1) Initialize consumers list: start from the output nodes and
@@ -124,7 +131,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
         // Producers are special nodes that generate data on demand.
         for (const auto& requiredProducer : requiredProducers) {
             requiredProducer->getOperator()->updateConsummerProducer();
-            schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
+            schedule.push_back(new StaticSchedulingElement(requiredProducer));
         }
 
         // 5) Find runnable consumers.
@@ -178,7 +185,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
         for (const auto& runnable : runnableConsumers) {
             Log::debug("Runnable: {}", namePtrTable.at(runnable));
             runnable->getOperator()->updateConsummerProducer();
-            schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
+            schedule.push_back(new StaticSchedulingElement(runnable));
         }
 
         // 7) Update consumers list
@@ -310,7 +317,7 @@ void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>
 }
 
 
-void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
+void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const {
     std::size_t latest = 0;
     // Calculate early (logical) start
     for (std::size_t elt = 0; elt < schedule.size(); ++elt) {
@@ -390,15 +397,20 @@ void Aidge::Scheduler::resetScheduling() {
     for (auto node : mGraphView->getNodes()) {
         node->getOperator()->resetConsummerProducer();
     }
-
+    for (auto& staticScheduleVec : mStaticSchedule) {
+        for (auto& staticScheduleElt : staticScheduleVec) {
+            delete staticScheduleElt;
+        }
+        staticScheduleVec.clear();
+    }
     mStaticSchedule.clear();
     mStaticScheduleStep = 0;
     mScheduling.clear();
 }
 
 /**
- * This version is a simplified version without special handling of concatenation.
-*/
+ * @warning This version is a simplified version without special handling of concatenation.
+ */
 Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
     MemoryManager memManager;
 
@@ -669,8 +681,8 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>&
     return Elts_t::NoneElts();
 }
 
-Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
-    const std::shared_ptr<Node>& node) const
+Aidge::Scheduler::PriorProducersConsumers
+Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) const
 {
     const auto priorCache = mPriorCache.find(node);
     if (priorCache != mPriorCache.end()) {
@@ -707,6 +719,7 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon
                     const auto& parentPrior = getPriorProducersConsumers(parent.first);
 
                     if (!parentPrior.isPrior) {
+                        // only happens in case of cyclic graphs
                         return PriorProducersConsumers(); // not scheduled
                     }
                     else {
diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp
index 88b5e98bc..4e6e91f51 100644
--- a/src/scheduler/SequentialScheduler.cpp
+++ b/src/scheduler/SequentialScheduler.cpp
@@ -45,7 +45,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
     }
 
     // Sort static scheduling according to the policy
-    std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
+    std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
 
     if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) {
         std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
-- 
GitLab