Commit b033e7aa authored by Olivier BICHLER

Fixed bugs in parallel scheduler

parent 9b1d3208
2 merge requests: !105 "version 0.2.0", !94 "Improved scheduling"
@@ -41,6 +41,7 @@ protected:
size_t early;
size_t late;
std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;
std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan;
};
struct SchedulingElement {
......
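The hunk above adds a `laterThan` list symmetric to the existing `earlierThan`: each element of the static schedule now records both the elements it must precede and the elements it must wait for. A minimal self-contained sketch of how the two lists pair up (the stripped-down struct and `main()` are illustrative; only the four members shown in the diff come from the source):

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Stripped-down stand-in for the scheduler's element; the real struct
// also carries the node pointer (see the hunk above).
struct StaticSchedulingElement {
    size_t early = 0;  // earliest step at which the node may run
    size_t late = 0;   // latest step by which the node must run
    std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;
    std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan;
};

int main() {
    auto producer = std::make_shared<StaticSchedulingElement>();
    auto consumer = std::make_shared<StaticSchedulingElement>();
    consumer->early = 1;
    consumer->late = 1;
    // Record the dependency in both directions: the producer must run
    // before the consumer, and the consumer must wait for the producer.
    producer->earlierThan.push_back(consumer);
    consumer->laterThan.push_back(producer);
}
```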
@@ -23,10 +23,10 @@
namespace Aidge {
class ThreadPool {
public:
void start(size_t nbThreads = std::thread::hardware_concurrency());
ThreadPool(size_t nbThreads = std::thread::hardware_concurrency());
void queueJob(const std::function<void()>& job);
void stop();
bool busy();
virtual ~ThreadPool();
private:
void threadLoop();
......
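The ThreadPool hunk above replaces the explicit `start()`/`stop()` pair with RAII: the constructor spawns the workers and the (virtual) destructor joins them. A minimal usage sketch under that reading (`runJobs` is hypothetical; `queueJob` and `busy` are taken from the declarations shown, and `busy()` is assumed to report pending or running jobs):

```cpp
#include <atomic>
#include <thread>

#include "aidge/scheduler/ThreadPool.hpp"

void runJobs() {
    Aidge::ThreadPool pool;  // workers start in the constructor
                             // (std::thread::hardware_concurrency() by default)
    std::atomic<int> counter{0};
    for (int i = 0; i < 8; ++i) {
        pool.queueJob([&counter]() { ++counter; });
    }
    // busy() stays true while jobs are queued or running
    while (pool.busy()) {
        std::this_thread::yield();
    }
}   // ~ThreadPool() terminates and joins the workers; no stop() call needed
```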
@@ -310,11 +310,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh
const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
[node](const auto& v) { return (v->node == node); });
// Node can be run the earliest just after it was run the last time!
// Node can be run the earliest just after its children were run the last time!
size_t early = 0;
if (itNode != schedule.rend()) {
early = (*itNode)->early + 1;
(*itNode)->earlierThan.push_back(schedule[elt]);
for (const auto& child : node->getChildren()) {
// Find the child's last scheduled position before the current one
const auto it = std::find_if(schedule.rend() - elt, itNode,
[child](const auto& v) { return (v->node == child); });
AIDGE_INTERNAL_ASSERT(it != schedule.rend());
const size_t step = std::distance(schedule.begin(), it.base()) - 1;
early = std::max(early, schedule[step]->early + 1);
schedule[step]->earlierThan.push_back(schedule[elt]);
}
}
// Node can be run the earliest just after its latest parent was run
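To see what the loop above fixes, consider a hypothetical two-node graph A -> B whose static schedule runs A twice. Under the old rule the second occurrence of A was bounded only by A's own previous run; under the new rule it is bounded by the last run of each of A's children, as the updated comment states:

```cpp
// Hypothetical static schedule for A -> B, with A executed twice:
//   step 0: A   (early = 0)
//   step 1: B   (early = 1, must follow A's first run)
//   step 2: A   (second occurrence)
//
// Old rule:   early(A@2) = early(A@0) + 1 = 1
//   -> A@2 could be scheduled in parallel with B@1, which still reads
//      the output of A's first run.
// Fixed rule: early(A@2) = max over A's children of (child's last run).early + 1
//                        = early(B@1) + 1 = 2
// The earlierThan link from B@1 to A@2 records this constraint for the
// parallel scheduler.
```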
@@ -339,10 +347,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh
const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[node](const auto& v) { return (v->node == node); });
// Node can be run the latest just before it is run the next time!
// Node can be run the latest just before its parents are run the next time!
size_t late = latest;
if (itNode != schedule.end()) {
late = (*itNode)->late - 1;
for (const auto& parent : node->getParents()) {
// Find the parent's next scheduled position after the current one
const auto it = std::find_if(schedule.begin() + elt + 1, itNode,
[parent](const auto& v) { return (v->node == parent); });
AIDGE_INTERNAL_ASSERT(it != schedule.end());
const size_t step = std::distance(schedule.begin(), it);
late = std::min(late, schedule[step]->late - 1);
schedule[step]->laterThan.push_back(schedule[elt]);
}
}
// Node can be run the latest just before its earliest child is run
@@ -353,6 +370,7 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh
if (it != schedule.end()) {
const size_t step = std::distance(schedule.begin(), it);
late = std::min(late, schedule[step]->late - 1);
schedule[step]->laterThan.push_back(schedule[elt]);
}
}
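The symmetric fix for the latest bound: a node's last admissible step is now constrained by the next run of each of its parents, not by its own next occurrence. Continuing the same hypothetical graph A -> B, unrolled twice:

```cpp
// Hypothetical static schedule: A@0, B@1, A@2, B@3
// Bound for B@1, assuming late(A@2) = 2 and late(B@3) = 3:
//
// Old rule:   late(B@1) = late(B@3) - 1 = 2
// Fixed rule: late(B@1) = min(late(B@1), late(A@2) - 1) = 1
//   -> B's first run must complete before its parent A runs again and
//      overwrites the data B reads; the laterThan link records this.
```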
@@ -733,42 +751,58 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::
// Thread pooling avoids the overhead of thread creation and deletion for
// each node execution.
ThreadPool pool;
pool.start();
size_t latest = 0;
std::mutex schedulingMutex;
std::vector<int> required(staticSchedule.back()->late + 1, 0);
std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1);
std::fill(finished.begin(), finished.end(), 0);
std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;
while (!staticSchedule.empty()) {
Log::debug("Step {}", latest);
// Run all nodes that must be run at latest
std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;
// Run all nodes that must be run at this step: latest (critical nodes)
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (runnable->late == latest) {
// Critical path
pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() {
// Wait for potential preceding non-critical nodes to finish
while (true) {
bool ready = true;
for (auto elt : runnable->laterThan) {
ready = ready && finished.at(elt);
}
if (!ready) {
std::this_thread::yield();
}
else {
break;
}
}
// Add the critical node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
++finished;
finished = true;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
++required[latest];
mustFinish.push_back(runnable);
Log::debug(" run critical {}", namePtrTable.at(runnable->node));
// Ensure the following nodes cannot start earlier than next step
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
}
}
}
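The forward() hunk above also replaces the per-step counters (`required`/`finished` indexed by step) with a per-element completion flag, so a critical node now waits precisely on its own `laterThan` predecessors. A condensed sketch of that spin-wait pattern (the `Elt` struct and free function are simplifications of the hunk, not the scheduler's API; `std::map` keeps node addresses stable, so references into `finished` stay valid while jobs run):

```cpp
#include <atomic>
#include <map>
#include <memory>
#include <thread>
#include <vector>

struct Elt {
    std::vector<std::shared_ptr<Elt>> laterThan;  // predecessors to wait for
};

// Simplified version of the hunk's wait loop: block until every
// predecessor recorded in laterThan has raised its completion flag.
void waitForPredecessors(
    const std::shared_ptr<Elt>& runnable,
    const std::map<std::shared_ptr<Elt>, std::atomic<bool>>& finished)
{
    while (true) {
        bool ready = true;
        for (const auto& elt : runnable->laterThan) {
            ready = ready && finished.at(elt);  // atomic load
        }
        if (ready) break;
        std::this_thread::yield();  // let worker threads make progress
    }
}
```

Keying completion by element rather than by step matters because non-critical nodes queued ahead of time can finish in any order; a single per-step counter cannot express which particular predecessor has finished.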
@@ -784,46 +818,76 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::
}
// If some threads are still available, run next early nodes
while (!pool.busy() && !staticSchedule.empty()) {
auto runnable = staticSchedule.front();
if (runnable->early <= latest) {
pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
++finished;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.pop_front();
++required[runnable->late];
// These nodes are non-critical, meaning they can still be run in the
// next step if they are not run now
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (!pool.busy() && runnable->early <= latest) {
// Check that potential preceding non-critical nodes are finished
bool ready = true;
for (auto elt : runnable->laterThan) {
ready = ready && finished.at(elt);
}
Log::debug(" run {}", namePtrTable.at(runnable->node));
if (ready) {
// All preceding nodes have finished, this node can be run.
// Add the node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
finished = true;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
Log::debug(" run {}", namePtrTable.at(runnable->node));
// Ensure the following nodes cannot start earlier than next step
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
}
}
}
else {
// The node cannot be run yet because preceding nodes are
// still running; move on to the next one in the schedule
++i;
}
}
else {
// Thread pool is already full or no more nodes can be run at
// this step (latest)
break;
}
}
// Wait for all nodes that must finish at latest to be finished
while (finished[latest] < required[latest]) {
std::this_thread::yield();
// By construction of the schedule, no other node can be started before
// all nodes at the latest step are finished
while (true) {
bool ready = true;
for (auto elt : mustFinish) {
ready = ready && finished.at(elt);
}
if (!ready) {
std::this_thread::yield();
}
else {
break;
}
}
++latest;
}
pool.stop();
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
......
@@ -11,7 +11,7 @@
#include "aidge/scheduler/ThreadPool.hpp"
void Aidge::ThreadPool::start(size_t nbThreads) {
Aidge::ThreadPool::ThreadPool(size_t nbThreads) {
for (size_t i = 0; i < nbThreads; ++i) {
mThreads.emplace_back(std::thread(&ThreadPool::threadLoop, this));
}
@@ -52,7 +52,7 @@ bool Aidge::ThreadPool::busy() {
return poolbusy;
}
void Aidge::ThreadPool::stop() {
Aidge::ThreadPool::~ThreadPool() {
{
std::unique_lock<std::mutex> lock(mQueueMutex);
mTerminate = true;
......
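For completeness, a self-contained sketch of a pool with the same shape as the one in this commit (the condition variable, job queue, and loop body are assumptions; only `mThreads`, `mQueueMutex`, `mTerminate`, the constructor loop, and the destructor's terminate-under-lock pattern appear in the diff):

```cpp
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Minimal RAII thread pool following the same shape as Aidge::ThreadPool.
class MiniThreadPool {
public:
    MiniThreadPool(size_t nbThreads = std::thread::hardware_concurrency()) {
        for (size_t i = 0; i < nbThreads; ++i)
            mThreads.emplace_back(&MiniThreadPool::threadLoop, this);
    }
    void queueJob(const std::function<void()>& job) {
        {
            std::unique_lock<std::mutex> lock(mQueueMutex);
            mJobs.push(job);
        }
        mCondition.notify_one();
    }
    ~MiniThreadPool() {
        {
            std::unique_lock<std::mutex> lock(mQueueMutex);
            mTerminate = true;  // same flag the diff sets in ~ThreadPool()
        }
        mCondition.notify_all();
        for (auto& t : mThreads) t.join();
    }
private:
    void threadLoop() {
        while (true) {
            std::function<void()> job;
            {
                std::unique_lock<std::mutex> lock(mQueueMutex);
                mCondition.wait(lock,
                    [this] { return mTerminate || !mJobs.empty(); });
                if (mTerminate && mJobs.empty()) return;  // drain, then exit
                job = std::move(mJobs.front());
                mJobs.pop();
            }
            job();  // run outside the lock so other workers can dequeue
        }
    }
    std::vector<std::thread> mThreads;
    std::mutex mQueueMutex;
    std::condition_variable mCondition;
    std::queue<std::function<void()>> mJobs;
    bool mTerminate = false;
};
```

With the destructor doing the join, the pool's lifetime in ParallelScheduler::forward() above becomes purely scoped, which is why the explicit pool.start()/pool.stop() calls disappear in this commit.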