Skip to content
Snippets Groups Projects
Commit 88673b1e authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Scheduler refactor

parent d4aa1fa6
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!94Improved scheduling
This commit is part of merge request !94. Comments created here will be created in the context of that merge request.
......@@ -15,7 +15,7 @@
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
namespace Aidge {
class MetaOperator_Op : public OperatorTensor,
......
......@@ -94,7 +94,9 @@ public:
inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); }
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
SET_IMPL_MACRO(Producer_Op, *this, name);
if (Registrar<Producer_Op>::exists(name)) {
SET_IMPL_MACRO(Producer_Op, *this, name);
}
mOutputs[0]->setBackend(name, device);
}
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_PARALLELSCHEDULER_H_
#define AIDGE_PARALLELSCHEDULER_H_
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <map>
#include "aidge/scheduler/Scheduler.hpp"
namespace Aidge {
/**
* Multi-threaded parallel scheduler with dynamic scheduling.
*/
class ParallelScheduler : public Scheduler {
public:
ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: Scheduler(graphView, upperNode)
{
// ctor
};
~ParallelScheduler() = default;
/**
* @brief Run the provided Computational Graph with a batch of data
*/
virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
};
} // namespace Aidge
#endif /* AIDGE_PARALLELSCHEDULER_H_ */
......@@ -28,7 +28,7 @@ namespace Aidge {
class Node;
class GraphView;
class SequentialScheduler {
class Scheduler {
protected:
struct StaticSchedulingElement {
StaticSchedulingElement(
......@@ -63,13 +63,13 @@ protected:
};
public:
SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
Scheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: mGraphView(graphView),
mUpperNode(upperNode)
{
// ctor
};
virtual ~SequentialScheduler() = default;
virtual ~Scheduler() = default;
/**
* Generate full static scheduling of the GraphView.
......@@ -98,11 +98,6 @@ public:
*/
void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
/**
* @brief Run the provided Computational Graph with a batch of data
*/
virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
/**
* @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
* @param fileName Name of the generated file.
......@@ -159,21 +154,6 @@ protected:
std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
size_t mStaticScheduleStep = 0;
};
/**
* Multi-threaded parallel scheduler with dynamic scheduling.
*/
class ParallelScheduler : public SequentialScheduler {
public:
ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: SequentialScheduler(graphView, upperNode)
{
// ctor
};
~ParallelScheduler() = default;
virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
};
} // namespace Aidge
#endif /* AIDGE_SCHEDULER_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_SEQUENTIALSCHEDULER_H_
#define AIDGE_SEQUENTIALSCHEDULER_H_
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <map>
#include "aidge/scheduler/Scheduler.hpp"
namespace Aidge {
/**
* Multi-threaded parallel scheduler with dynamic scheduling.
*/
class SequentialScheduler : public Scheduler {
public:
enum SchedulingPolicy {
Default,
AsSoonAsPossible,
AsLateAsPossible
};
SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: Scheduler(graphView, upperNode),
mSchedulingPolicy(Default)
{
// ctor
};
inline void setSchedulingPolicy(SchedulingPolicy policy) {
mSchedulingPolicy = policy;
}
~SequentialScheduler() = default;
/**
* @brief Run the provided Computational Graph with a batch of data
*/
virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
private:
SchedulingPolicy mSchedulingPolicy;
};
} // namespace Aidge
#endif /* AIDGE_SEQUENTIALSCHEDULER_H_ */
......@@ -131,18 +131,15 @@ void declare_registrable(py::module& m, const std::string& class_name){
*/
#ifdef PYBIND
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
\
if(Py_IsInitialized()) { \
auto obj = py::cast(&(op)); \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} else { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
}
if(Py_IsInitialized()) { \
auto obj = py::cast(&(op)); \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
} else { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
}
#else
#define SET_IMPL_MACRO(T_Op, op, backend_name) \
if (Registrar<T_Op>::exists(backend_name)) { \
(op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \
}
(op).setImpl(Registrar<T_Op>::create(backend_name)(op));
#endif
}
......
......@@ -12,19 +12,30 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/ParallelScheduler.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scheduler(py::module& m){
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler")
py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>())
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &SequentialScheduler::resetScheduling)
.def("generate_scheduling", &SequentialScheduler::generateScheduling)
.def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0)
.def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &Scheduler::resetScheduling)
.def("generate_scheduling", &Scheduler::generateScheduling)
.def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
;
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
;
py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &ParallelScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
;
}
}
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/scheduler/ParallelScheduler.hpp"
#include "aidge/scheduler/ThreadPool.hpp"
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <fmt/ranges.h>
#include <fmt/color.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
void Aidge::ParallelScheduler::forward(bool forwardDims, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Sort static scheduling, the order will be the prefered threads scheduling
// order for non critical nodes
std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
// The thread pool has N threads running to process nodes.
// Thread pooling avoid the overhead of threads creation and deletion for
// each node execution.
ThreadPool pool;
size_t latest = 0;
std::mutex schedulingMutex;
std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;
while (!staticSchedule.empty()) {
Log::debug("Step {}", latest);
std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;
// Run all nodes that must be run at this step: latest (critical nodes)
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (runnable->late == latest) {
// Wait for potential preceding non-critical nodes to finish
while (true) {
bool ready = true;
for (auto elt : runnable->laterThan) {
ready = ready && finished.at(elt);
}
if (!ready) {
std::this_thread::yield();
}
else {
break;
}
}
// Add the critical node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
finished = true;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
mustFinish.push_back(runnable);
Log::debug(" run critical {}", namePtrTable.at(runnable->node));
// Ensure the following nodes cannot start earlier than next step
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
}
}
}
else if (runnable->early > latest + 1) {
// There cannot be more node that must be run at latest + 1
// latest + 1 and not latest because early may have been updated
// for some elements to latest + 1 (above).
break;
}
else {
++i;
}
}
// If some threads are still available, run next early nodes
// These nodes are non-critical, meaning they can still be run at least
// in the next step
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (!pool.busy() && runnable->early <= latest) {
// Check that potential preceding non-critical nodes are finished
bool ready = true;
for (auto elt : runnable->laterThan) {
ready = ready && finished.at(elt);
}
if (ready) {
// All preceding nodes have finished, this node can be run.
// Add the node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
finished = true;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
Log::debug(" run {}", namePtrTable.at(runnable->node));
// Ensure the following nodes cannot start earlier than next step
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
}
}
}
else {
// The node cannot be run yet, because preceding nodes are
// still running, move to the next one in schedule
++i;
}
}
else {
// Thread pool is already full or no more node can be run at
// this step (latest)
break;
}
}
// Wait for all nodes that must finish at latest to be finished
// By scheduling construction, no other node can be started before all
// nodes at latest step are finished
while (true) {
bool ready = true;
for (auto elt : mustFinish) {
ready = ready && finished.at(elt);
}
if (!ready) {
std::this_thread::yield();
}
else {
break;
}
}
++latest;
}
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
......@@ -10,7 +10,6 @@
********************************************************************************/
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/ThreadPool.hpp"
#include <chrono>
#include <memory>
......@@ -28,31 +27,13 @@
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
putchar('[');
int pos = static_cast<int>(barWidth * progress);
for (int i = 0; i < barWidth; ++i) {
if (i <= pos)
putchar('#');
else
putchar(' ');
}
fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo);
fflush(stdout);
}
void Aidge::SequentialScheduler::generateScheduling() {
void Aidge::Scheduler::generateScheduling() {
auto schedule = generateBaseScheduling();
generateEarlyLateScheduling(schedule);
mStaticSchedule.push_back(schedule);
}
std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>> Aidge::SequentialScheduler::generateBaseScheduling() const {
// TODO: For loop on the list of node to run
// run sequencially every runnable consumers once
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const {
// 1) Setup initial consumers list:
// It is the list of input nodes
std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
......@@ -302,7 +283,7 @@ std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>
return schedule;
}
void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
size_t latest = 0;
// Calculate early (logical) start
for (size_t elt = 0; elt < schedule.size(); ++elt) {
......@@ -378,7 +359,7 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh
}
}
void Aidge::SequentialScheduler::resetScheduling() {
void Aidge::Scheduler::resetScheduling() {
for (auto node : mGraphView->getNodes()) {
node->getOperator()->resetConsummerProducer();
}
......@@ -391,7 +372,7 @@ void Aidge::SequentialScheduler::resetScheduling() {
/**
* This version is a simplified version without special handling of concatenation.
*/
Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
MemoryManager memManager;
for (size_t step = 0; step < mStaticSchedule.size(); ++step) {
......@@ -497,7 +478,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
return memManager;
}
void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){
void Aidge::Scheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){
// This version of connect inputs only connects tensor inputs in input data producers.
auto inputNodes = mGraphView->getOrderedInputs();
......@@ -510,49 +491,7 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge
}
}
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
size_t cpt = 0;
for (const auto& runnable : getStaticScheduling(mStaticScheduleStep)) {
if (verbose)
fmt::print("run: {}\n", namePtrTable.at(runnable));
else
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
(std::string("running ") + namePtrTable.at(runnable)));
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
cpt++;
}
if (!verbose) drawProgressBar(1.0, 50, " ");
fmt::print("\n");
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
......@@ -582,7 +521,7 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa
fmt::print(fp.get(), "\n");
}
void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName) const {
void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
......@@ -611,14 +550,14 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string&
fmt::print(fp.get(), "\n");
}
std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const {
std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(size_t step) const {
const auto& staticSchedule = mStaticSchedule.at(step);
std::vector<std::shared_ptr<Node>> schedule;
std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
return schedule;
}
std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers(
const std::set<std::shared_ptr<Node>>& producers) const {
std::set<std::shared_ptr<Node>> consumers;
......@@ -635,7 +574,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
return consumers;
}
Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
const auto parent = node->inputs()[inputIdx];
if (parent.first) {
......@@ -676,7 +615,7 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared
return 0;
}
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const
{
PriorProducersConsumers prior;
......@@ -721,175 +660,3 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::
}
return prior;
}
void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Sort static scheduling, the order will be the prefered threads scheduling
// order for non critical nodes
std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
// The thread pool has N threads running to process nodes.
// Thread pooling avoid the overhead of threads creation and deletion for
// each node execution.
ThreadPool pool;
size_t latest = 0;
std::mutex schedulingMutex;
std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;
while (!staticSchedule.empty()) {
Log::debug("Step {}", latest);
std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;
// Run all nodes that must be run at this step: latest (critical nodes)
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (runnable->late == latest) {
// Wait for potential preceding non-critical nodes to finish
while (true) {
bool ready = true;
for (auto elt : runnable->laterThan) {
ready = ready && finished.at(elt);
}
if (!ready) {
std::this_thread::yield();
}
else {
break;
}
}
// Add the critical node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
finished = true;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
mustFinish.push_back(runnable);
Log::debug(" run critical {}", namePtrTable.at(runnable->node));
// Ensure the following nodes cannot start earlier than next step
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
}
}
}
else if (runnable->early > latest + 1) {
// There cannot be more node that must be run at latest + 1
// latest + 1 and not latest because early may have been updated
// for some elements to latest + 1 (above).
break;
}
else {
++i;
}
}
// If some threads are still available, run next early nodes
// These nodes are non-critical, meaning they can still be run at least
// in the next step
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (!pool.busy() && runnable->early <= latest) {
// Check that potential preceding non-critical nodes are finished
bool ready = true;
for (auto elt : runnable->laterThan) {
ready = ready && finished.at(elt);
}
if (ready) {
// All preceding nodes have finished, this node can be run.
// Add the node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
finished = true;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
Log::debug(" run {}", namePtrTable.at(runnable->node));
// Ensure the following nodes cannot start earlier than next step
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
}
}
}
else {
// The node cannot be run yet, because preceding nodes are
// still running, move to the next one in schedule
++i;
}
}
else {
// Thread pool is already full or no more node can be run at
// this step (latest)
break;
}
}
// Wait for all nodes that must finish at latest to be finished
// By scheduling construction, no other node can be started before all
// nodes at latest step are finished
while (true) {
bool ready = true;
for (auto elt : mustFinish) {
ready = ready && finished.at(elt);
}
if (!ready) {
std::this_thread::yield();
}
else {
break;
}
}
++latest;
}
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/scheduler/SequentialScheduler.hpp"
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <fmt/ranges.h>
#include <fmt/color.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
void Aidge::SequentialScheduler::forward(bool forwardDims, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
// Sort static scheduling according to the policy
std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
if (mSchedulingPolicy == AsSoonAsPossible) {
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); });
}
else if (mSchedulingPolicy == AsLateAsPossible) {
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); });
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
for (const auto& runnable : staticSchedule) {
Log::debug("run: {}", namePtrTable.at(runnable->node));
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd));
}
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
......@@ -25,7 +25,7 @@
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
using namespace Aidge;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment