diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 2992ef652c0ee632ea4d41c2817d35f487ab44c2..8155bb872d93b0032217604eb0294b8f608ec4be 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -29,7 +29,7 @@ class Node;
 class GraphView;
 
 class SequentialScheduler {
-private:
+protected:
     struct StaticSchedulingElement {
         StaticSchedulingElement(
             std::shared_ptr<Node> node_,
@@ -40,6 +40,7 @@ private:
         std::shared_ptr<Node> node;
         size_t early;
         size_t late;
+        std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;
     };
 
     struct SchedulingElement {
@@ -67,7 +68,7 @@ public:
     {
         // ctor
     };
-    ~SequentialScheduler() = default;
+    virtual ~SequentialScheduler() = default;
 
     /**
      * Generate full static scheduling of the GraphView.
@@ -99,7 +100,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
+    virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
 
     /**
      * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
@@ -122,20 +123,20 @@ public:
         return mGraphView;
     }
 
-private:
+protected:
     /**
      * Generate an initial base scheduling for the GraphView.
      * The scheduling is entirely sequential and guaranteed to be valid w.r.t.
      * each node producer-consumer model.
      */
-    std::vector<StaticSchedulingElement> generateBaseScheduling() const;
+    std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const;
 
     /**
      * Fill-in early and late scheduling step from initial base scheduling.
      * For each node, specifies the earliest and latest possible execution
      * logical step.
      */
-    void generateEarlyLateScheduling(std::vector<StaticSchedulingElement>& schedule) const;
+    void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const;
 
     /**
      * @brief Set of layers receiving an input from currently processing layers
@@ -154,9 +155,28 @@
     /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
     std::vector<SchedulingElement> mScheduling;
    /** @brief List of nodes ordered by their */
-    std::vector<std::vector<StaticSchedulingElement>> mStaticSchedule;
+    std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
     size_t mStaticScheduleStep = 0;
 };
+
+/**
+ * Multi-threaded parallel scheduler with dynamic scheduling.
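+ * Nodes are dispatched to a thread pool as soon as their dependencies are
+ * satisfied, using the early/late logical steps computed by the static
+ * scheduling: nodes on the critical path (late == current step) are queued
+ * first, then idle threads pick up nodes whose early step has been reached.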
+*/
+class ParallelScheduler : public SequentialScheduler {
+public:
+    ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
+        : SequentialScheduler(graphView, upperNode)
+    {
+        // ctor
+    };
+    ~ParallelScheduler() = default;
+
+    virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
+};
 } // namespace Aidge
 
 #endif /* AIDGE_SCHEDULER_H_ */
diff --git a/include/aidge/scheduler/ThreadPool.hpp b/include/aidge/scheduler/ThreadPool.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..22d0c3f38c8328655d6ab9b4abd0305d6ee55bd8
--- /dev/null
+++ b/include/aidge/scheduler/ThreadPool.hpp
@@ -0,0 +1,42 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_THREADPOOL_H_
+#define AIDGE_THREADPOOL_H_
+
+#include <thread>
+#include <mutex>
+#include <queue>
+#include <vector>
+#include <functional>
+#include <condition_variable>
+#include <atomic>
+
+namespace Aidge {
+class ThreadPool {
+public:
+    void start(size_t nbThreads = std::thread::hardware_concurrency());
+    void queueJob(const std::function<void()>& job);
+    void stop();
+    bool busy();
+
+private:
+    void threadLoop();
+
+    bool mTerminate = false;
+    std::mutex mQueueMutex;
+    std::condition_variable mMutexCondition;
+    std::vector<std::thread> mThreads;
+    std::queue<std::function<void()>> mJobs;
+};
+} // namespace Aidge
+
+#endif /* AIDGE_THREADPOOL_H_ */
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 1fba2de64c6ed0e33357f78802c5ae25b125eeff..040752875896687c2b95a7d3d04b6aeb5f4197b5 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -10,6 +10,7 @@
  ********************************************************************************/
 
 #include "aidge/scheduler/Scheduler.hpp"
+#include "aidge/scheduler/ThreadPool.hpp"
 
 #include <chrono>
 #include <memory>
@@ -46,7 +47,7 @@ void Aidge::SequentialScheduler::generateScheduling() {
     mStaticSchedule.push_back(schedule);
 }
 
-std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::SequentialScheduler::generateBaseScheduling() const {
+std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>> Aidge::SequentialScheduler::generateBaseScheduling() const {
     // TODO: For loop on the list of nodes to run
     // run sequentially every runnable consumer once
     // TODO: handle memory allocation in scheduler
@@ -73,7 +74,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
     // producers-consumers model!
     std::set<std::shared_ptr<Node>> stillConsumers;
-    std::vector<StaticSchedulingElement> schedule;
+    std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
 
     do {
         // 2) From the current consumers list, check if any prior consumer node
         // is also a consumer!
@@ -125,7 +126,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
         // Producers are special nodes that generate data on demand.
         for (const auto& requiredProducer : requiredProducers) {
             requiredProducer->getOperator()->updateConsummerProducer();
-            schedule.push_back(StaticSchedulingElement(requiredProducer));
+            schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
         }
 
         // 5) Find runnable consumers.
@@ -194,7 +195,7 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
         for (const auto& runnable : runnableConsumers) {
             Log::debug("Runnable: {}", namePtrTable.at(runnable));
             runnable->getOperator()->updateConsummerProducer();
-            schedule.push_back(StaticSchedulingElement(runnable));
+            schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
         }
 
         // 7) Update consumers list
@@ -301,59 +302,61 @@ std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::Sequenti
     return schedule;
 }
 
-void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement>& schedule) const {
+void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
     size_t latest = 0;
     // Calculate early (logical) start
     for (size_t elt = 0; elt < schedule.size(); ++elt) {
-        const auto node = schedule[elt].node;
+        const auto node = schedule[elt]->node;
         const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
-            [node](const auto& v) { return (v.node == node); });
+            [node](const auto& v) { return (v->node == node); });
 
         // Node can be run the earliest just after it was run the last time!
         size_t early = 0;
         if (itNode != schedule.rend()) {
-            early = (*itNode).early + 1;
+            early = (*itNode)->early + 1;
+            (*itNode)->earlierThan.push_back(schedule[elt]);
         }
 
         // Node can be run the earliest just after its latest parent was run
        for (const auto& parent : node->getParents()) {
             // Find parent node latest scheduled position
             const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
-                [parent](const auto& v) { return (v.node == parent); });
+                [parent](const auto& v) { return (v->node == parent); });
             if (it != schedule.rend()) {
                 const size_t step = std::distance(schedule.begin(), it.base()) - 1;
-                early = std::max(early, schedule[step].early + 1);
+                early = std::max(early, schedule[step]->early + 1);
+                schedule[step]->earlierThan.push_back(schedule[elt]);
             }
         }
 
         latest = std::max(latest, early);
-        schedule[elt].early = early;
+        schedule[elt]->early = early;
     }
 
     // Calculate late (logical) start
     for (size_t elt = schedule.size(); elt-- != 0; ) {
-        const auto node = schedule[elt].node;
+        const auto node = schedule[elt]->node;
         const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
-            [node](const auto& v) { return (v.node == node); });
+            [node](const auto& v) { return (v->node == node); });
 
         // Node can be run the latest just before it is run the next time!
         size_t late = latest;
         if (itNode != schedule.end()) {
-            late = (*itNode).late - 1;
+            late = (*itNode)->late - 1;
         }
 
         // Node can be run the latest just before its earliest child is run
         for (const auto& child : node->getChildren()) {
             // Find child node earliest scheduled position
             const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
-                [child](const auto& v) { return (v.node == child); });
+                [child](const auto& v) { return (v->node == child); });
             if (it != schedule.end()) {
                 const size_t step = std::distance(schedule.begin(), it);
-                late = std::min(late, schedule[step].late - 1);
+                late = std::min(late, schedule[step]->late - 1);
             }
         }
-        schedule[elt].late = late;
+        schedule[elt]->late = late;
     }
 }
@@ -577,12 +580,12 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string&
 
     for (const auto& schedule : mStaticSchedule) {
         for (const auto& element : schedule) {
-            auto name = namePtrTable.at(element.node);
+            auto name = namePtrTable.at(element->node);
             // Mermaid does not allow : character in task title
             std::replace(name.begin(), name.end(), ':', '_');
 
             fmt::print(fp.get(), "{} :{}, {}\n",
-                       name, element.early, element.late);
+                       name, element->early, element->late);
         }
     }
 }
@@ -593,7 +596,7 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string&
 std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const {
     const auto& staticSchedule = mStaticSchedule.at(step);
     std::vector<std::shared_ptr<Node>> schedule;
-    std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v.node; });
+    std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
     return schedule;
 }
 
@@ -700,3 +703,132 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:
     }
     return prior;
 }
+
+void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
+
+    // Collect all data inputs of the graph (these are Producers)
+    if (!data.empty()) {
+        connectInputs(data);
+    }
+
+    // Forward dims (if allowed)
+    if (forwardDims) { mGraphView->forwardDims(); }
+
+    // Generate scheduling *only if empty*
+    // If scheduling was already generated (in one or several steps, i.e. one or
+    // several successive calls to generateScheduling()), do not generate it twice
+    if (mStaticSchedule.empty()) {
+        this->generateScheduling();
+    }
+
+    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
+
+    // Sort the static scheduling: this order will be the preferred thread
+    // scheduling order for non-critical nodes
+    std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
+    std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
+        [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
+
+    // The thread pool has N threads running to process nodes.
+    // Thread pooling avoids the overhead of thread creation and deletion for
+    // each node execution.
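+    // "required" counts the jobs queued for each logical step and "finished"
+    // counts the completed ones, so that each step can be synchronized
+    // before advancing to the next.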
+    ThreadPool pool;
+    pool.start();
+
+    size_t latest = 0;
+    std::mutex schedulingMutex;
+    std::vector<int> required(staticSchedule.back()->late + 1, 0);
+    std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1);
+    std::fill(finished.begin(), finished.end(), 0);
+
+    while (!staticSchedule.empty()) {
+        Log::debug("Step {}", latest);
+
+        // Run all nodes that must be run at this step (late == latest)
+        for (size_t i = 0; i < staticSchedule.size(); ) {
+            auto runnable = staticSchedule[i];
+
+            if (runnable->late == latest) {
+                // Critical path
+                pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() {
+                    const auto tStart = std::chrono::high_resolution_clock::now();
+                    node->forward();
+                    const auto tEnd = std::chrono::high_resolution_clock::now();
+                    ++finished;
+                    {
+                        std::unique_lock<std::mutex> lock(schedulingMutex);
+                        mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
+                    }
+                });
+                staticSchedule.erase(staticSchedule.begin() + i);
+                ++required[latest];
+
+                Log::debug("  run critical {}", namePtrTable.at(runnable->node));
+
+                for (auto elt : runnable->earlierThan) {
+                    if (elt->early < latest + 1) {
+                        Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
+                        elt->early = latest + 1;
+                    }
+                }
+            }
+            else if (runnable->early > latest + 1) {
+                // There cannot be any more nodes that must run at this step.
+                // The bound is latest + 1 and not latest because early may
+                // have been updated to latest + 1 for some elements (above).
+                break;
+            }
+            else {
+                ++i;
+            }
+        }
+
+        // If some threads are still available, run the next early nodes
+        while (!pool.busy() && !staticSchedule.empty()) {
+            auto runnable = staticSchedule.front();
+            if (runnable->early <= latest) {
+                pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() {
+                    const auto tStart = std::chrono::high_resolution_clock::now();
+                    node->forward();
+                    const auto tEnd = std::chrono::high_resolution_clock::now();
+                    ++finished;
+                    {
+                        std::unique_lock<std::mutex> lock(schedulingMutex);
+                        mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
+                    }
+                });
+                staticSchedule.pop_front();
+                ++required[runnable->late];
+
+                Log::debug("  run {}", namePtrTable.at(runnable->node));
+
+                for (auto elt : runnable->earlierThan) {
+                    if (elt->early < latest + 1) {
+                        Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
+                        elt->early = latest + 1;
+                    }
+                }
+            }
+            else {
+                break;
+            }
+        }
+
+        // Wait for all nodes that must finish at this step to be finished
+        while (finished[latest] < required[latest]) {
+            std::this_thread::yield();
+        }
+
+        ++latest;
+    }
+
+    pool.stop();
+
+    ++mStaticScheduleStep;
+    if (mStaticScheduleStep == mStaticSchedule.size()) {
+        mStaticScheduleStep = 0;
+    }
+}
diff --git a/src/scheduler/ThreadPool.cpp b/src/scheduler/ThreadPool.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..52fab31047c4d48fd081020941a7887c27f2b5b1
--- /dev/null
+++ b/src/scheduler/ThreadPool.cpp
@@ -0,0 +1,65 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/scheduler/ThreadPool.hpp"
+
+void Aidge::ThreadPool::start(size_t nbThreads) {
+    for (size_t i = 0; i < nbThreads; ++i) {
+        mThreads.emplace_back(std::thread(&ThreadPool::threadLoop, this));
+    }
+}
+
+void Aidge::ThreadPool::threadLoop() {
+    while (true) {
+        std::function<void()> job;
+        {
+            std::unique_lock<std::mutex> lock(mQueueMutex);
+            mMutexCondition.wait(lock, [this] {
+                return !mJobs.empty() || mTerminate;
+            });
+            if (mTerminate) {
+                return;
+            }
+            job = mJobs.front();
+            mJobs.pop();
+        }
+        job();
+    }
+}
+
+void Aidge::ThreadPool::queueJob(const std::function<void()>& job) {
+    {
+        std::unique_lock<std::mutex> lock(mQueueMutex);
+        mJobs.push(job);
+    }
+    mMutexCondition.notify_one();
+}
+
+bool Aidge::ThreadPool::busy() {
+    bool poolbusy;
+    {
+        std::unique_lock<std::mutex> lock(mQueueMutex);
+        poolbusy = !mJobs.empty();
+    }
+    return poolbusy;
+}
+
+void Aidge::ThreadPool::stop() {
+    {
+        std::unique_lock<std::mutex> lock(mQueueMutex);
+        mTerminate = true;
+    }
+    mMutexCondition.notify_all();
+    for (std::thread& active_thread : mThreads) {
+        active_thread.join();
+    }
+    mThreads.clear();
+}
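
For reviewers, a minimal usage sketch of the new scheduler (not part of the patch; the runParallel helper and the caller-provided graph and inputs are illustrative). ParallelScheduler is a drop-in replacement for SequentialScheduler: the first forward() call connects the inputs, forwards the dims and generates the static scheduling, then dispatches nodes to the thread pool step by step:

    #include <memory>
    #include <vector>

    #include "aidge/scheduler/Scheduler.hpp"

    // Hypothetical driver code: run a GraphView with the parallel scheduler.
    void runParallel(std::shared_ptr<Aidge::GraphView> graph,
                     std::vector<std::shared_ptr<Aidge::Tensor>> inputs) {
        Aidge::ParallelScheduler scheduler(graph);
        // Nodes queued within the same logical step run concurrently;
        // the verbose flag is ignored by the parallel implementation.
        scheduler.forward(/*forwardDims=*/true, /*verbose=*/false, inputs);
    }

Subsequent forward() calls reuse the stored static schedule and cycle through mStaticScheduleStep, mirroring the sequential scheduler's behavior.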