Skip to content
Snippets Groups Projects
Commit 9b1d3208 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added new parallel Scheduler

parent 6f5ce957
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!94Improved scheduling
Pipeline #40984 passed
...@@ -29,7 +29,7 @@ class Node; ...@@ -29,7 +29,7 @@ class Node;
class GraphView; class GraphView;
class SequentialScheduler { class SequentialScheduler {
private: protected:
struct StaticSchedulingElement { struct StaticSchedulingElement {
StaticSchedulingElement( StaticSchedulingElement(
std::shared_ptr<Node> node_, std::shared_ptr<Node> node_,
...@@ -40,6 +40,7 @@ private: ...@@ -40,6 +40,7 @@ private:
std::shared_ptr<Node> node; std::shared_ptr<Node> node;
size_t early; size_t early;
size_t late; size_t late;
std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;
}; };
struct SchedulingElement { struct SchedulingElement {
...@@ -67,7 +68,7 @@ public: ...@@ -67,7 +68,7 @@ public:
{ {
// ctor // ctor
}; };
~SequentialScheduler() = default; virtual ~SequentialScheduler() = default;
/** /**
* Generate full static scheduling of the GraphView. * Generate full static scheduling of the GraphView.
...@@ -99,7 +100,7 @@ public: ...@@ -99,7 +100,7 @@ public:
/** /**
* @brief Run the provided Computational Graph with a batch of data * @brief Run the provided Computational Graph with a batch of data
*/ */
void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
/** /**
* @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes. * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
...@@ -122,20 +123,20 @@ public: ...@@ -122,20 +123,20 @@ public:
return mGraphView; return mGraphView;
} }
private: protected:
/** /**
* Generate an initial base scheduling for the GraphView. * Generate an initial base scheduling for the GraphView.
* The scheduling is entirely sequential and garanteed to be valid w.r.t. * The scheduling is entirely sequential and garanteed to be valid w.r.t.
* each node producer-consumer model. * each node producer-consumer model.
*/ */
std::vector<StaticSchedulingElement> generateBaseScheduling() const; std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const;
/** /**
* Fill-in early and late scheduling step from initial base scheduling. * Fill-in early and late scheduling step from initial base scheduling.
* For each node, specifies the earliest and latest possible execution * For each node, specifies the earliest and latest possible execution
* logical step. * logical step.
*/ */
void generateEarlyLateScheduling(std::vector<StaticSchedulingElement>& schedule) const; void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const;
/** /**
* @brief Set of layers receiving an input from currently processing layers * @brief Set of layers receiving an input from currently processing layers
...@@ -154,9 +155,24 @@ private: ...@@ -154,9 +155,24 @@ private:
/** @brief List of SchedulingElement (i.e: Nodes with their computation time) */ /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
std::vector<SchedulingElement> mScheduling; std::vector<SchedulingElement> mScheduling;
/** @brief List of nodes ordered by their */ /** @brief List of nodes ordered by their */
std::vector<std::vector<StaticSchedulingElement>> mStaticSchedule; std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
size_t mStaticScheduleStep = 0; size_t mStaticScheduleStep = 0;
}; };
/**
* Multi-threaded parallel scheduler with dynamic scheduling.
*/
class ParallelScheduler : public SequentialScheduler {
public:
ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: SequentialScheduler(graphView, upperNode)
{
// ctor
};
~ParallelScheduler() = default;
virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
};
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_SCHEDULER_H_ */ #endif /* AIDGE_SCHEDULER_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_THREADPOOL_H_
#define AIDGE_THREADPOOL_H_
#include <thread>
#include <mutex>
#include <queue>
#include <vector>
#include <functional>
#include <condition_variable>
#include <atomic>
namespace Aidge {
class ThreadPool {
public:
void start(size_t nbThreads = std::thread::hardware_concurrency());
void queueJob(const std::function<void()>& job);
void stop();
bool busy();
private:
void threadLoop();
bool mTerminate = false;
std::mutex mQueueMutex;
std::condition_variable mMutexCondition;
std::vector<std::thread> mThreads;
std::queue<std::function<void()>> mJobs;
};
} // namespace Aidge
#endif /* AIDGE_THREADPOOL_H_ */
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
********************************************************************************/ ********************************************************************************/
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/ThreadPool.hpp"
#include <chrono> #include <chrono>
#include <memory> #include <memory>
...@@ -46,7 +47,7 @@ void Aidge::SequentialScheduler::generateScheduling() { ...@@ -46,7 +47,7 @@ void Aidge::SequentialScheduler::generateScheduling() {
mStaticSchedule.push_back(schedule); mStaticSchedule.push_back(schedule);
} }
std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::SequentialScheduler::generateBaseScheduling() const { std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>> Aidge::SequentialScheduler::generateBaseScheduling() const {
// TODO: For loop on the list of node to run // TODO: For loop on the list of node to run
// run sequencially every runnable consumers once // run sequencially every runnable consumers once
// TODO: handle memory allocation in scheduler // TODO: handle memory allocation in scheduler
...@@ -73,7 +74,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti ...@@ -73,7 +74,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
// producers-consumers model! // producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers; std::set<std::shared_ptr<Node>> stillConsumers;
std::vector<StaticSchedulingElement> schedule; std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
do { do {
// 2) From the current consumers list, check if any prior consumer node // 2) From the current consumers list, check if any prior consumer node
...@@ -125,7 +126,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti ...@@ -125,7 +126,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
// Producers are special nodes that generate data on demand. // Producers are special nodes that generate data on demand.
for (const auto& requiredProducer : requiredProducers) { for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer(); requiredProducer->getOperator()->updateConsummerProducer();
schedule.push_back(StaticSchedulingElement(requiredProducer)); schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
} }
// 5) Find runnable consumers. // 5) Find runnable consumers.
...@@ -194,7 +195,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti ...@@ -194,7 +195,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
for (const auto& runnable : runnableConsumers) { for (const auto& runnable : runnableConsumers) {
Log::debug("Runnable: {}", namePtrTable.at(runnable)); Log::debug("Runnable: {}", namePtrTable.at(runnable));
runnable->getOperator()->updateConsummerProducer(); runnable->getOperator()->updateConsummerProducer();
schedule.push_back(StaticSchedulingElement(runnable)); schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
} }
// 7) Update consumers list // 7) Update consumers list
...@@ -301,59 +302,61 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti ...@@ -301,59 +302,61 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
return schedule; return schedule;
} }
void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement>& schedule) const { void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
size_t latest = 0; size_t latest = 0;
// Calculate early (logical) start // Calculate early (logical) start
for (size_t elt = 0; elt < schedule.size(); ++elt) { for (size_t elt = 0; elt < schedule.size(); ++elt) {
const auto node = schedule[elt].node; const auto node = schedule[elt]->node;
const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(), const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
[node](const auto& v) { return (v.node == node); }); [node](const auto& v) { return (v->node == node); });
// Node can be run the earliest just after it was run the last time! // Node can be run the earliest just after it was run the last time!
size_t early = 0; size_t early = 0;
if (itNode != schedule.rend()) { if (itNode != schedule.rend()) {
early = (*itNode).early + 1; early = (*itNode)->early + 1;
(*itNode)->earlierThan.push_back(schedule[elt]);
} }
// Node can be run the earliest just after its latest parent was run // Node can be run the earliest just after its latest parent was run
for (const auto& parent : node->getParents()) { for (const auto& parent : node->getParents()) {
// Find parent node latest scheduled position // Find parent node latest scheduled position
const auto it = std::find_if(schedule.rend() - elt, schedule.rend(), const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
[parent](const auto& v) { return (v.node == parent); }); [parent](const auto& v) { return (v->node == parent); });
if (it != schedule.rend()) { if (it != schedule.rend()) {
const size_t step = std::distance(schedule.begin(), it.base()) - 1; const size_t step = std::distance(schedule.begin(), it.base()) - 1;
early = std::max(early, schedule[step].early + 1); early = std::max(early, schedule[step]->early + 1);
schedule[step]->earlierThan.push_back(schedule[elt]);
} }
} }
latest = std::max(latest, early); latest = std::max(latest, early);
schedule[elt].early = early; schedule[elt]->early = early;
} }
// Calculate late (logical) start // Calculate late (logical) start
for (size_t elt = schedule.size(); elt-- != 0; ) { for (size_t elt = schedule.size(); elt-- != 0; ) {
const auto node = schedule[elt].node; const auto node = schedule[elt]->node;
const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(), const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[node](const auto& v) { return (v.node == node); }); [node](const auto& v) { return (v->node == node); });
// Node can be run the latest just before it is run the next time! // Node can be run the latest just before it is run the next time!
size_t late = latest; size_t late = latest;
if (itNode != schedule.end()) { if (itNode != schedule.end()) {
late = (*itNode).late - 1; late = (*itNode)->late - 1;
} }
// Node can be run the latest just before its earliest child is run // Node can be run the latest just before its earliest child is run
for (const auto& child : node->getChildren()) { for (const auto& child : node->getChildren()) {
// Find child node earliest scheduled position // Find child node earliest scheduled position
const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[child](const auto& v) { return (v.node == child); }); [child](const auto& v) { return (v->node == child); });
if (it != schedule.end()) { if (it != schedule.end()) {
const size_t step = std::distance(schedule.begin(), it); const size_t step = std::distance(schedule.begin(), it);
late = std::min(late, schedule[step].late - 1); late = std::min(late, schedule[step]->late - 1);
} }
} }
schedule[elt].late = late; schedule[elt]->late = late;
} }
} }
...@@ -577,12 +580,12 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& ...@@ -577,12 +580,12 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string&
for (const auto& schedule : mStaticSchedule) { for (const auto& schedule : mStaticSchedule) {
for (const auto& element : schedule) { for (const auto& element : schedule) {
auto name = namePtrTable.at(element.node); auto name = namePtrTable.at(element->node);
// Mermaid does not allow : character in task title // Mermaid does not allow : character in task title
std::replace(name.begin(), name.end(), ':', '_'); std::replace(name.begin(), name.end(), ':', '_');
fmt::print(fp.get(), "{} :{}, {}\n", fmt::print(fp.get(), "{} :{}, {}\n",
name, element.early, element.late); name, element->early, element->late);
} }
} }
} }
...@@ -593,7 +596,7 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& ...@@ -593,7 +596,7 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string&
std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const { std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const {
const auto& staticSchedule = mStaticSchedule.at(step); const auto& staticSchedule = mStaticSchedule.at(step);
std::vector<std::shared_ptr<Node>> schedule; std::vector<std::shared_ptr<Node>> schedule;
std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v.node; }); std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
return schedule; return schedule;
} }
...@@ -700,3 +703,129 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:: ...@@ -700,3 +703,129 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::
} }
return prior; return prior;
} }
void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Sort static scheduling, the order will be the prefered threads scheduling
// order for non critical nodes
std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
// The thread pool has N threads running to process nodes.
// Thread pooling avoid the overhead of threads creation and deletion for
// each node execution.
ThreadPool pool;
pool.start();
size_t latest = 0;
std::mutex schedulingMutex;
std::vector<int> required(staticSchedule.back()->late + 1, 0);
std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1);
std::fill(finished.begin(), finished.end(), 0);
while (!staticSchedule.empty()) {
Log::debug("Step {}", latest);
// Run all nodes that must be run at latest
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (runnable->late == latest) {
// Critical path
pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
++finished;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
++required[latest];
Log::debug(" run critical {}", namePtrTable.at(runnable->node));
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
}
}
}
else if (runnable->early > latest + 1) {
// There cannot be more node that must be run at latest + 1
// latest + 1 and not latest because early may have been updated
// for some elements to latest + 1 (above).
break;
}
else {
++i;
}
}
// If some threads are still available, run next early nodes
while (!pool.busy() && !staticSchedule.empty()) {
auto runnable = staticSchedule.front();
if (runnable->early <= latest) {
pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
++finished;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.pop_front();
++required[runnable->late];
Log::debug(" run {}", namePtrTable.at(runnable->node));
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
}
}
}
else {
break;
}
}
// Wait for all nodes that must finish at latest to be finished
while (finished[latest] < required[latest]) {
std::this_thread::yield();
}
++latest;
}
pool.stop();
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/scheduler/ThreadPool.hpp"
void Aidge::ThreadPool::start(size_t nbThreads) {
for (size_t i = 0; i < nbThreads; ++i) {
mThreads.emplace_back(std::thread(&ThreadPool::threadLoop, this));
}
}
void Aidge::ThreadPool::threadLoop() {
while (true) {
std::function<void()> job;
{
std::unique_lock<std::mutex> lock(mQueueMutex);
mMutexCondition.wait(lock, [this] {
return !mJobs.empty() || mTerminate;
});
if (mTerminate) {
return;
}
job = mJobs.front();
mJobs.pop();
}
job();
}
}
void Aidge::ThreadPool::queueJob(const std::function<void()>& job) {
{
std::unique_lock<std::mutex> lock(mQueueMutex);
mJobs.push(job);
}
mMutexCondition.notify_one();
}
bool Aidge::ThreadPool::busy() {
bool poolbusy;
{
std::unique_lock<std::mutex> lock(mQueueMutex);
poolbusy = !mJobs.empty();
}
return poolbusy;
}
void Aidge::ThreadPool::stop() {
{
std::unique_lock<std::mutex> lock(mQueueMutex);
mTerminate = true;
}
mMutexCondition.notify_all();
for (std::thread& active_thread : mThreads) {
active_thread.join();
}
mThreads.clear();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment