Commit b033e7aa authored by Olivier BICHLER

Fixed bugs in parallel scheduler

parent 9b1d3208
Part of 2 merge requests: !105 "version 0.2.0", !94 "Improved scheduling"
@@ -41,6 +41,7 @@ protected:
     size_t early;
     size_t late;
     std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;
+    std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan;
 };

 struct SchedulingElement {
...
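The core data-structure change is this mirrored dependency list. As a sketch (reconstructed from the hunk above; the remaining members of the real header are elided), the element now reads:

    #include <cstddef>
    #include <memory>
    #include <vector>

    struct StaticSchedulingElement {
        size_t early;  // earliest step at which the node may run
        size_t late;   // latest step at which the node must have run
        // Elements that must not start before this one (this one runs earlier):
        std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;
        // Elements that must finish before this one starts (this one runs later):
        std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan;
    };

An entry B in A->earlierThan means "A runs earlier than B"; the new A entry in B->laterThan is the reverse view, letting the parallel scheduler poll a node's own prerequisites instead of counting completions per step.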
@@ -23,10 +23,10 @@
 namespace Aidge {
 class ThreadPool {
 public:
-    void start(size_t nbThreads = std::thread::hardware_concurrency());
+    ThreadPool(size_t nbThreads = std::thread::hardware_concurrency());
     void queueJob(const std::function<void()>& job);
-    void stop();
     bool busy();
+    virtual ~ThreadPool();

 private:
     void threadLoop();
...
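The ThreadPool API change replaces explicit start()/stop() calls with RAII: the constructor spawns the workers and the destructor shuts them down, so a scoped `ThreadPool pool;` cannot leak threads on an early return. Below is a minimal self-contained sketch of such a pool, assuming a mutex/condition-variable job queue consistent with the member names visible in the last hunk (an illustration, not the Aidge implementation; the real busy() semantics are an assumption here):

    #include <condition_variable>
    #include <cstddef>
    #include <functional>
    #include <mutex>
    #include <queue>
    #include <thread>
    #include <vector>

    class ThreadPool {
    public:
        ThreadPool(size_t nbThreads = std::thread::hardware_concurrency()) {
            for (size_t i = 0; i < nbThreads; ++i) {
                mThreads.emplace_back(&ThreadPool::threadLoop, this);
            }
        }

        void queueJob(const std::function<void()>& job) {
            {
                std::unique_lock<std::mutex> lock(mQueueMutex);
                mJobs.push(job);
            }
            mCondition.notify_one();
        }

        bool busy() {
            // One plausible meaning (an assumption): jobs still waiting for a worker.
            std::unique_lock<std::mutex> lock(mQueueMutex);
            return !mJobs.empty();
        }

        virtual ~ThreadPool() {
            {
                std::unique_lock<std::mutex> lock(mQueueMutex);
                mTerminate = true;
            }
            mCondition.notify_all();
            for (auto& t : mThreads) { t.join(); }
        }

    private:
        void threadLoop() {
            while (true) {
                std::function<void()> job;
                {
                    std::unique_lock<std::mutex> lock(mQueueMutex);
                    mCondition.wait(lock,
                        [this]() { return !mJobs.empty() || mTerminate; });
                    if (mTerminate && mJobs.empty()) return;
                    job = std::move(mJobs.front());
                    mJobs.pop();
                }
                job();  // run the job outside the lock
            }
        }

        bool mTerminate = false;
        std::mutex mQueueMutex;
        std::condition_variable mCondition;
        std::queue<std::function<void()>> mJobs;
        std::vector<std::thread> mThreads;
    };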
@@ -310,11 +310,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh
         const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
             [node](const auto& v) { return (v->node == node); });

-        // Node can be run the earliest just after it was run the last time!
+        // Node can be run the earliest just after its children were run the last time!
         size_t early = 0;
         if (itNode != schedule.rend()) {
-            early = (*itNode)->early + 1;
-            (*itNode)->earlierThan.push_back(schedule[elt]);
+            for (const auto& child : node->getChildren()) {
+                // Find the child node's next scheduled position
+                const auto it = std::find_if(schedule.rend() - elt, itNode,
+                    [child](const auto& v) { return (v->node == child); });
+                AIDGE_INTERNAL_ASSERT(it != schedule.rend());
+                const size_t step = std::distance(schedule.begin(), it.base()) - 1;
+                early = std::max(early, schedule[step]->early + 1);
+                schedule[step]->earlierThan.push_back(schedule[elt]);
+            }
         }

         // Node can be run the earliest just after its latest parent was run
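The subtle line in this hunk is the index recovery from a reverse iterator: for a reverse iterator it, it.base() points one element past the element it refers to, so the forward index of that element is std::distance(schedule.begin(), it.base()) - 1. A standalone check of that arithmetic:

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <iterator>
    #include <vector>

    int main() {
        std::vector<int> v = {10, 20, 30, 40};
        auto rit = std::find(v.rbegin(), v.rend(), 20);  // refers to v[1]
        const std::size_t step = std::distance(v.begin(), rit.base()) - 1;
        assert(step == 1 && v[step] == 20);              // same element recovered
        return 0;
    }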
@@ -339,10 +347,19 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh
         const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
             [node](const auto& v) { return (v->node == node); });

-        // Node can be run the latest just before it is run the next time!
+        // Node can be run the latest just before its parents are run the next time!
         size_t late = latest;
         if (itNode != schedule.end()) {
-            late = (*itNode)->late - 1;
+            for (const auto& parent : node->getParents()) {
+                // Find the parent node's next scheduled position
+                const auto it = std::find_if(schedule.begin() + elt + 1, itNode,
+                    [parent](const auto& v) { return (v->node == parent); });
+                AIDGE_INTERNAL_ASSERT(it != schedule.end());
+                const size_t step = std::distance(schedule.begin(), it);
+                late = std::min(late, schedule[step]->late - 1);
+                schedule[step]->laterThan.push_back(schedule[elt]);
+            }
         }

         // Node can be run the latest just before its earliest child is run
@@ -353,6 +370,7 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh
             if (it != schedule.end()) {
                 const size_t step = std::distance(schedule.begin(), it);
                 late = std::min(late, schedule[step]->late - 1);
+                schedule[step]->laterThan.push_back(schedule[elt]);
             }
         }
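As a worked example of the two passes, take a hypothetical diamond graph A -> B, A -> C, B -> D, C -> D, statically scheduled as [A, B, C, D] (so latest = 3). The bounds come out as:

    early:  A=0   B=1   C=1   D=2
    late:   A=1   B=2   C=2   D=3

B and C share the window [1, 2] and may therefore run concurrently; the late pass also fills D->laterThan with B and C, which is exactly the prerequisite list the ParallelScheduler polls below before launching D.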
@@ -733,42 +751,58 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::
     // Thread pooling avoids the overhead of thread creation and deletion for
     // each node execution.
     ThreadPool pool;
-    pool.start();

     size_t latest = 0;
     std::mutex schedulingMutex;
-    std::vector<int> required(staticSchedule.back()->late + 1, 0);
-    std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1);
-    std::fill(finished.begin(), finished.end(), 0);
+    std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;

     while (!staticSchedule.empty()) {
         Log::debug("Step {}", latest);

-        // Run all nodes that must be run at latest
+        std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;
+
+        // Run all nodes that must be run at this step: latest (critical nodes)
         for (size_t i = 0; i < staticSchedule.size(); ) {
             auto runnable = staticSchedule[i];

             if (runnable->late == latest) {
-                // Critical path
-                pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() {
+                // Wait for potential preceding non-critical nodes to finish
+                while (true) {
+                    bool ready = true;
+                    for (auto elt : runnable->laterThan) {
+                        ready = ready && finished.at(elt);
+                    }
+                    if (!ready) {
+                        std::this_thread::yield();
+                    }
+                    else {
+                        break;
+                    }
+                }
+
+                // Add the critical node to the thread pool queue, to be run ASAP
+                finished[runnable] = false;
+                pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
                     const auto tStart = std::chrono::high_resolution_clock::now();
                     node->forward();
                     const auto tEnd = std::chrono::high_resolution_clock::now();
-                    ++finished;
+                    finished = true;
                     {
                         std::unique_lock<std::mutex> lock(schedulingMutex);
                         mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
                     }
                 });
                 staticSchedule.erase(staticSchedule.begin() + i);
-                ++required[latest];
+                mustFinish.push_back(runnable);

                 Log::debug("  run critical {}", namePtrTable.at(runnable->node));

+                // Ensure the following nodes cannot start earlier than the next step
                 for (auto elt : runnable->earlierThan) {
                     if (elt->early < latest + 1) {
                         Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
                         elt->early = latest + 1;
+                        AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
                     }
                 }
             }
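This poll-and-yield pattern reappears twice in the next hunk. Assuming the flags are only ever flipped from false to true by the worker threads, it could be factored into a small helper like the hypothetical one below (not part of the commit):

    #include <algorithm>
    #include <atomic>
    #include <map>
    #include <memory>
    #include <thread>
    #include <vector>

    struct StaticSchedulingElement;  // the struct from the first hunk

    // Hypothetical helper: spin until every element of `deps` has had its
    // completion flag set by a worker, yielding the CPU between polls.
    void waitForAll(const std::vector<std::shared_ptr<StaticSchedulingElement>>& deps,
                    const std::map<std::shared_ptr<StaticSchedulingElement>,
                                   std::atomic<bool>>& finished)
    {
        while (!std::all_of(deps.begin(), deps.end(),
                   [&finished](const auto& elt) { return finished.at(elt).load(); }))
        {
            std::this_thread::yield();
        }
    }

A condition variable per step would avoid the spinning, at the cost of heavier synchronization for what are typically very short waits.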
@@ -784,46 +818,76 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::
         }

         // If some threads are still available, run next early nodes
-        while (!pool.busy() && !staticSchedule.empty()) {
-            auto runnable = staticSchedule.front();
-            if (runnable->early <= latest) {
-                pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() {
-                    const auto tStart = std::chrono::high_resolution_clock::now();
-                    node->forward();
-                    const auto tEnd = std::chrono::high_resolution_clock::now();
-                    ++finished;
-                    {
-                        std::unique_lock<std::mutex> lock(schedulingMutex);
-                        mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
-                    }
-                });
-                staticSchedule.pop_front();
-                ++required[runnable->late];
-
-                Log::debug("  run {}", namePtrTable.at(runnable->node));
-
-                for (auto elt : runnable->earlierThan) {
-                    if (elt->early < latest + 1) {
-                        Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
-                        elt->early = latest + 1;
-                    }
-                }
-            }
-            else {
-                break;
-            }
-        }
+        // These nodes are non-critical, meaning they can still be run at least
+        // in the next step
+        for (size_t i = 0; i < staticSchedule.size(); ) {
+            auto runnable = staticSchedule[i];
+            if (!pool.busy() && runnable->early <= latest) {
+                // Check that potential preceding non-critical nodes are finished
+                bool ready = true;
+                for (auto elt : runnable->laterThan) {
+                    ready = ready && finished.at(elt);
+                }
+
+                if (ready) {
+                    // All preceding nodes have finished, so this node can be run.
+                    // Add the node to the thread pool queue, to be run ASAP
+                    finished[runnable] = false;
+                    pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
+                        const auto tStart = std::chrono::high_resolution_clock::now();
+                        node->forward();
+                        const auto tEnd = std::chrono::high_resolution_clock::now();
+                        finished = true;
+                        {
+                            std::unique_lock<std::mutex> lock(schedulingMutex);
+                            mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
+                        }
+                    });
+                    staticSchedule.erase(staticSchedule.begin() + i);
+
+                    Log::debug("  run {}", namePtrTable.at(runnable->node));
+
+                    // Ensure the following nodes cannot start earlier than the next step
+                    for (auto elt : runnable->earlierThan) {
+                        if (elt->early < latest + 1) {
+                            Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
+                            elt->early = latest + 1;
+                            AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
+                        }
+                    }
+                }
+                else {
+                    // The node cannot be run yet, because preceding nodes are
+                    // still running; move to the next one in the schedule
+                    ++i;
+                }
+            }
+            else {
+                // Thread pool is already full, or no more nodes can be run at
+                // this step (latest)
+                break;
+            }
+        }

         // Wait for all nodes that must finish at latest to be finished
-        while (finished[latest] < required[latest]) {
-            std::this_thread::yield();
-        }
+        // By scheduling construction, no other node can be started before all
+        // nodes at the latest step are finished
+        while (true) {
+            bool ready = true;
+            for (auto elt : mustFinish) {
+                ready = ready && finished.at(elt);
+            }
+            if (!ready) {
+                std::this_thread::yield();
+            }
+            else {
+                break;
+            }
+        }

         ++latest;
     }
-    pool.stop();

     ++mStaticScheduleStep;
     if (mStaticScheduleStep == mStaticSchedule.size()) {
         mStaticScheduleStep = 0;
...
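One detail worth calling out: each queued lambda captures finished.at(runnable) by reference while the main loop keeps inserting new keys into the same std::map. This is safe because std::map never invalidates references to existing elements on insertion (and std::atomic is neither copyable nor movable, which rules out containers that relocate their elements). A minimal demonstration of the guarantee:

    #include <atomic>
    #include <cassert>
    #include <map>

    int main() {
        std::map<int, std::atomic<bool>> finished;
        finished[0] = false;                       // operator[] default-constructs, then assigns
        std::atomic<bool>& flag = finished.at(0);  // reference held across later inserts
        for (int k = 1; k < 1000; ++k) finished[k] = false;
        flag = true;                               // the reference is still valid
        assert(finished.at(0));
        return 0;
    }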
@@ -11,7 +11,7 @@

 #include "aidge/scheduler/ThreadPool.hpp"

-void Aidge::ThreadPool::start(size_t nbThreads) {
+Aidge::ThreadPool::ThreadPool(size_t nbThreads) {
     for (size_t i = 0; i < nbThreads; ++i) {
         mThreads.emplace_back(std::thread(&ThreadPool::threadLoop, this));
     }
...
@@ -52,7 +52,7 @@ bool Aidge::ThreadPool::busy() {
     return poolbusy;
 }

-void Aidge::ThreadPool::stop() {
+Aidge::ThreadPool::~ThreadPool() {
     {
         std::unique_lock<std::mutex> lock(mQueueMutex);
         mTerminate = true;
...
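With start() and stop() folded into the constructor and destructor, shutting the pool down is tied to scope exit. Assuming the destructor goes on to notify and join the workers, as the stop() it replaces did (the tail of the function is truncated above), usage in the scheduler reduces to:

    {
        ThreadPool pool;                   // workers start here
        pool.queueJob([]() { /* node->forward() ... */ });
        // ...
    }                                      // ~ThreadPool(): mTerminate set, workers joined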