From b033e7aa5c1061a6fbae63bb888528b7425875fb Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 12 Mar 2024 15:27:44 +0100 Subject: [PATCH] Fixed bugs in parallel scheduler --- include/aidge/scheduler/Scheduler.hpp | 1 + include/aidge/scheduler/ThreadPool.hpp | 4 +- src/scheduler/Scheduler.cpp | 140 ++++++++++++++++++------- src/scheduler/ThreadPool.cpp | 4 +- 4 files changed, 107 insertions(+), 42 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 8155bb872..df5ca3d1d 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -41,6 +41,7 @@ protected: size_t early; size_t late; std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; + std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; }; struct SchedulingElement { diff --git a/include/aidge/scheduler/ThreadPool.hpp b/include/aidge/scheduler/ThreadPool.hpp index 22d0c3f38..5f2d9192d 100644 --- a/include/aidge/scheduler/ThreadPool.hpp +++ b/include/aidge/scheduler/ThreadPool.hpp @@ -23,10 +23,10 @@ namespace Aidge { class ThreadPool { public: - void start(size_t nbThreads = std::thread::hardware_concurrency()); + ThreadPool(size_t nbThreads = std::thread::hardware_concurrency()); void queueJob(const std::function<void()>& job); - void stop(); bool busy(); + virtual ~ThreadPool(); private: void threadLoop(); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 040752875..11715901c 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -310,11 +310,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(), [node](const auto& v) { return (v->node == node); }); - // Node can be run the earliest just after it was run the last time! + // Node can be run the earliest just after its children were run the last time! 
size_t early = 0; if (itNode != schedule.rend()) { - early = (*itNode)->early + 1; - (*itNode)->earlierThan.push_back(schedule[elt]); + for (const auto& child : node->getChildren()) { + // Find child node next scheduled position + const auto it = std::find_if(schedule.rend() - elt, itNode, + [child](const auto& v) { return (v->node == child); }); + AIDGE_INTERNAL_ASSERT(it != schedule.rend()); + + const size_t step = std::distance(schedule.begin(), it.base()) - 1; + early = std::max(early, schedule[step]->early + 1); + schedule[step]->earlierThan.push_back(schedule[elt]); + } } // Node can be run the earliest just after its latest parent was run @@ -339,10 +347,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(), [node](const auto& v) { return (v->node == node); }); - // Node can be run the latest just before it is run the next time! + // Node can be run the latest just before its parents are run the next time! 
size_t late = latest; if (itNode != schedule.end()) { - late = (*itNode)->late - 1; + for (const auto& parent : node->getParents()) { + // Find parent node next scheduled position + const auto it = std::find_if(schedule.begin() + elt + 1, itNode, + [parent](const auto& v) { return (v->node == parent); }); + AIDGE_INTERNAL_ASSERT(it != schedule.end()); + + const size_t step = std::distance(schedule.begin(), it); + late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); + } } // Node can be run the latest just before its earliest child is run @@ -353,6 +370,7 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh if (it != schedule.end()) { const size_t step = std::distance(schedule.begin(), it); late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); } } @@ -733,42 +751,58 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std:: // Thread pooling avoid the overhead of threads creation and deletion for // each node execution. 
ThreadPool pool; - pool.start(); size_t latest = 0; std::mutex schedulingMutex; - std::vector<int> required(staticSchedule.back()->late + 1, 0); - std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1); - std::fill(finished.begin(), finished.end(), 0); + std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished; while (!staticSchedule.empty()) { Log::debug("Step {}", latest); - // Run all nodes that must be run at latest + std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish; + + // Run all nodes that must be run at this step: latest (critical nodes) for (size_t i = 0; i < staticSchedule.size(); ) { auto runnable = staticSchedule[i]; if (runnable->late == latest) { - // Critical path - pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() { + // Wait for potential preceding non-critical nodes to finish + while (true) { + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } + } + + // Add the critical node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { const auto tStart = std::chrono::high_resolution_clock::now(); node->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); - ++finished; + finished = true; { std::unique_lock<std::mutex> lock(schedulingMutex); mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); } }); staticSchedule.erase(staticSchedule.begin() + i); - ++required[latest]; + mustFinish.push_back(runnable); Log::debug(" run critical {}", namePtrTable.at(runnable->node)); + // Ensure the following nodes cannot start earlier than next step for (auto elt : runnable->earlierThan) { if (elt->early < latest + 1) { Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); 
elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); } } } @@ -784,46 +818,76 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std:: } // If some threads are still available, run next early nodes - while (!pool.busy() && !staticSchedule.empty()) { - auto runnable = staticSchedule.front(); - if (runnable->early <= latest) { - pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() { - const auto tStart = std::chrono::high_resolution_clock::now(); - node->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - ++finished; - { - std::unique_lock<std::mutex> lock(schedulingMutex); - mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); - } - }); - staticSchedule.pop_front(); - ++required[runnable->late]; + // These nodes are non-critical, meaning they can still be run at least + // in the next step + for (size_t i = 0; i < staticSchedule.size(); ) { + auto runnable = staticSchedule[i]; + if (!pool.busy() && runnable->early <= latest) { + // Check that potential preceding non-critical nodes are finished + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } - Log::debug(" run {}", namePtrTable.at(runnable->node)); + if (ready) { + // All preceding nodes have finished, this node can be run. 
+ // Add the node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + const auto tStart = std::chrono::high_resolution_clock::now(); + node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + finished = true; + { + std::unique_lock<std::mutex> lock(schedulingMutex); + mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); + } + }); + staticSchedule.erase(staticSchedule.begin() + i); - for (auto elt : runnable->earlierThan) { - if (elt->early < latest + 1) { - Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); - elt->early = latest + 1; + Log::debug(" run {}", namePtrTable.at(runnable->node)); + + // Ensure the following nodes cannot start earlier than next step + for (auto elt : runnable->earlierThan) { + if (elt->early < latest + 1) { + Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); + elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); + } } } + else { + // The node cannot be run yet, because preceding nodes are + // still running, move to the next one in schedule + ++i; + } } else { + // Thread pool is already full or no more node can be run at + // this step (latest) break; } } // Wait for all nodes that must finish at latest to be finished - while (finished[latest] < required[latest]) { - std::this_thread::yield(); + // By scheduling construction, no other node can be started before all + // nodes at latest step are finished + while (true) { + bool ready = true; + for (auto elt : mustFinish) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } } ++latest; } - pool.stop(); - ++mStaticScheduleStep; if (mStaticScheduleStep == mStaticSchedule.size()) { mStaticScheduleStep = 0; diff --git a/src/scheduler/ThreadPool.cpp b/src/scheduler/ThreadPool.cpp index 
52fab3104..e81ab7a76 100644 --- a/src/scheduler/ThreadPool.cpp +++ b/src/scheduler/ThreadPool.cpp @@ -11,7 +11,7 @@ #include "aidge/scheduler/ThreadPool.hpp" -void Aidge::ThreadPool::start(size_t nbThreads) { +Aidge::ThreadPool::ThreadPool(size_t nbThreads) { for (size_t i = 0; i < nbThreads; ++i) { mThreads.emplace_back(std::thread(&ThreadPool::threadLoop, this)); } @@ -52,7 +52,7 @@ bool Aidge::ThreadPool::busy() { return poolbusy; } -void Aidge::ThreadPool::stop() { +Aidge::ThreadPool::~ThreadPool() { { std::unique_lock<std::mutex> lock(mQueueMutex); mTerminate = true; -- GitLab