From b033e7aa5c1061a6fbae63bb888528b7425875fb Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 12 Mar 2024 15:27:44 +0100 Subject: [PATCH] Fixed bugs in parallel scheduler --- include/aidge/scheduler/Scheduler.hpp | 1 + include/aidge/scheduler/ThreadPool.hpp | 4 +- src/scheduler/Scheduler.cpp | 140 ++++++++++++++++++------- src/scheduler/ThreadPool.cpp | 4 +- 4 files changed, 107 insertions(+), 42 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 8155bb872..df5ca3d1d 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -41,6 +41,7 @@ protected: size_t early; size_t late; std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; + std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; }; struct SchedulingElement { diff --git a/include/aidge/scheduler/ThreadPool.hpp b/include/aidge/scheduler/ThreadPool.hpp index 22d0c3f38..5f2d9192d 100644 --- a/include/aidge/scheduler/ThreadPool.hpp +++ b/include/aidge/scheduler/ThreadPool.hpp @@ -23,10 +23,10 @@ namespace Aidge { class ThreadPool { public: - void start(size_t nbThreads = std::thread::hardware_concurrency()); + ThreadPool(size_t nbThreads = std::thread::hardware_concurrency()); void queueJob(const std::function<void()>& job); - void stop(); bool busy(); + virtual ~ThreadPool(); private: void threadLoop(); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 040752875..11715901c 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -310,11 +310,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(), [node](const auto& v) { return (v->node == node); }); - // Node can be run the earliest just after it was run the last time! + // Node can be run the earliest just after its children were run the last time! 
size_t early = 0; if (itNode != schedule.rend()) { - early = (*itNode)->early + 1; - (*itNode)->earlierThan.push_back(schedule[elt]); + for (const auto& child : node->getChildren()) { + // Find child node next scheduled position + const auto it = std::find_if(schedule.rend() - elt, itNode, + [child](const auto& v) { return (v->node == child); }); + AIDGE_INTERNAL_ASSERT(it != schedule.rend()); + + const size_t step = std::distance(schedule.begin(), it.base()) - 1; + early = std::max(early, schedule[step]->early + 1); + schedule[step]->earlierThan.push_back(schedule[elt]); + } } // Node can be run the earliest just after its latest parent was run @@ -339,10 +347,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(), [node](const auto& v) { return (v->node == node); }); - // Node can be run the latest just before it is run the next time! + // Node can be run the latest just before its parents are run the next time! 
size_t late = latest; if (itNode != schedule.end()) { - late = (*itNode)->late - 1; + for (const auto& parent : node->getParents()) { + // Find parent node next scheduled position + const auto it = std::find_if(schedule.begin() + elt + 1, itNode, + [parent](const auto& v) { return (v->node == parent); }); + AIDGE_INTERNAL_ASSERT(it != schedule.end()); + + const size_t step = std::distance(schedule.begin(), it); + late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); + } } // Node can be run the latest just before its earliest child is run @@ -353,6 +370,7 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh if (it != schedule.end()) { const size_t step = std::distance(schedule.begin(), it); late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); } } @@ -733,42 +751,58 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std:: // Thread pooling avoid the overhead of threads creation and deletion for // each node execution. 
ThreadPool pool; - pool.start(); size_t latest = 0; std::mutex schedulingMutex; - std::vector<int> required(staticSchedule.back()->late + 1, 0); - std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1); - std::fill(finished.begin(), finished.end(), 0); + std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished; while (!staticSchedule.empty()) { Log::debug("Step {}", latest); - // Run all nodes that must be run at latest + std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish; + + // Run all nodes that must be run at this step: latest (critical nodes) for (size_t i = 0; i < staticSchedule.size(); ) { auto runnable = staticSchedule[i]; if (runnable->late == latest) { - // Critical path - pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() { + // Wait for potential preceding non-critical nodes to finish + while (true) { + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } + } + + // Add the critical node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { const auto tStart = std::chrono::high_resolution_clock::now(); node->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); - ++finished; + finished = true; { std::unique_lock<std::mutex> lock(schedulingMutex); mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); } }); staticSchedule.erase(staticSchedule.begin() + i); - ++required[latest]; + mustFinish.push_back(runnable); Log::debug(" run critical {}", namePtrTable.at(runnable->node)); + // Ensure the following nodes cannot start earlier than next step for (auto elt : runnable->earlierThan) { if (elt->early < latest + 1) { Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); 
elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); } } } @@ -784,46 +818,76 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std:: } // If some threads are still available, run next early nodes - while (!pool.busy() && !staticSchedule.empty()) { - auto runnable = staticSchedule.front(); - if (runnable->early <= latest) { - pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() { - const auto tStart = std::chrono::high_resolution_clock::now(); - node->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - ++finished; - { - std::unique_lock<std::mutex> lock(schedulingMutex); - mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); - } - }); - staticSchedule.pop_front(); - ++required[runnable->late]; + // These nodes are non-critical, meaning they can still be run at least + // in the next step + for (size_t i = 0; i < staticSchedule.size(); ) { + auto runnable = staticSchedule[i]; + if (!pool.busy() && runnable->early <= latest) { + // Check that potential preceding non-critical nodes are finished + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } - Log::debug(" run {}", namePtrTable.at(runnable->node)); + if (ready) { + // All preceding nodes have finished, this node can be run. 
+ // Add the node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + const auto tStart = std::chrono::high_resolution_clock::now(); + node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + finished = true; + { + std::unique_lock<std::mutex> lock(schedulingMutex); + mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); + } + }); + staticSchedule.erase(staticSchedule.begin() + i); - for (auto elt : runnable->earlierThan) { - if (elt->early < latest + 1) { - Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); - elt->early = latest + 1; + Log::debug(" run {}", namePtrTable.at(runnable->node)); + + // Ensure the following nodes cannot start earlier than next step + for (auto elt : runnable->earlierThan) { + if (elt->early < latest + 1) { + Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); + elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); + } } } + else { + // The node cannot be run yet, because preceding nodes are + // still running, move to the next one in schedule + ++i; + } } else { + // Thread pool is already full or no more node can be run at + // this step (latest) break; } } // Wait for all nodes that must finish at latest to be finished - while (finished[latest] < required[latest]) { - std::this_thread::yield(); + // By scheduling construction, no other node can be started before all + // nodes at latest step are finished + while (true) { + bool ready = true; + for (auto elt : mustFinish) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } } ++latest; } - pool.stop(); - ++mStaticScheduleStep; if (mStaticScheduleStep == mStaticSchedule.size()) { mStaticScheduleStep = 0; diff --git a/src/scheduler/ThreadPool.cpp b/src/scheduler/ThreadPool.cpp index 
52fab3104..e81ab7a76 100644 --- a/src/scheduler/ThreadPool.cpp +++ b/src/scheduler/ThreadPool.cpp @@ -11,7 +11,7 @@ #include "aidge/scheduler/ThreadPool.hpp" -void Aidge::ThreadPool::start(size_t nbThreads) { +Aidge::ThreadPool::ThreadPool(size_t nbThreads) { for (size_t i = 0; i < nbThreads; ++i) { mThreads.emplace_back(std::thread(&ThreadPool::threadLoop, this)); } @@ -52,7 +52,7 @@ bool Aidge::ThreadPool::busy() { return poolbusy; } -void Aidge::ThreadPool::stop() { +Aidge::ThreadPool::~ThreadPool() { { std::unique_lock<std::mutex> lock(mQueueMutex); mTerminate = true; -- GitLab