diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 7f36eca2c4586f61f72e0d842d2d576450cd1596..ce328c23fe1abbc0fb4716c73b842f16afe23cc6 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -15,7 +15,7 @@ #include "aidge/operator/OperatorTensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" -#include "aidge/scheduler/Scheduler.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" namespace Aidge { class MetaOperator_Op : public OperatorTensor, diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index c9b1f6e4aa5d82006d4bed880151ac1a22a4882b..e5e2578e89371ee6897fc413ec4f21e402f45bd5 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -94,7 +94,9 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); } void setBackend(const std::string& name, DeviceIdx_t device = 0) override { - SET_IMPL_MACRO(Producer_Op, *this, name); + if (Registrar<Producer_Op>::exists(name)) { + SET_IMPL_MACRO(Producer_Op, *this, name); + } mOutputs[0]->setBackend(name, device); } diff --git a/include/aidge/scheduler/ParallelScheduler.hpp b/include/aidge/scheduler/ParallelScheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d471c65ff2d3e8a81c3992d1df06ba387559025e --- /dev/null +++ b/include/aidge/scheduler/ParallelScheduler.hpp @@ -0,0 +1,44 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_PARALLELSCHEDULER_H_ +#define AIDGE_PARALLELSCHEDULER_H_ + +#include <chrono> +#include <memory> +#include <set> +#include <string> +#include <vector> +#include <map> + +#include "aidge/scheduler/Scheduler.hpp" + +namespace Aidge { +/** + * Multi-threaded parallel scheduler with dynamic scheduling. +*/ +class ParallelScheduler : public Scheduler { +public: + ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) + : Scheduler(graphView, upperNode) + { + // ctor + }; + ~ParallelScheduler() = default; + + /** + * @brief Run the provided Computational Graph with a batch of data + */ + virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); +}; +} // namespace Aidge + +#endif /* AIDGE_PARALLELSCHEDULER_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index df5ca3d1d84e9108d88f2462ed46f130efb52d8f..e0284f0fbb8debaefc056b94b34bf9e8a0675b1f 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -28,7 +28,7 @@ namespace Aidge { class Node; class GraphView; -class SequentialScheduler { +class Scheduler { protected: struct StaticSchedulingElement { StaticSchedulingElement( @@ -63,13 +63,13 @@ protected: }; public: - SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) + Scheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) : mGraphView(graphView), mUpperNode(upperNode) { // ctor }; - virtual ~SequentialScheduler() = default; + virtual ~Scheduler() = default; /** * Generate full static scheduling of the GraphView. @@ -98,11 +98,6 @@ public: */ void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data); - /** - * @brief Run the provided Computational Graph with a batch of data - */ - virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); - /** * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes. * @param fileName Name of the generated file. @@ -159,21 +154,6 @@ protected: std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule; size_t mStaticScheduleStep = 0; }; - -/** - * Multi-threaded parallel scheduler with dynamic scheduling. -*/ -class ParallelScheduler : public SequentialScheduler { -public: - ParallelScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) - : SequentialScheduler(graphView, upperNode) - { - // ctor - }; - ~ParallelScheduler() = default; - - virtual void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); -}; } // namespace Aidge #endif /* AIDGE_SCHEDULER_H_ */ diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7a854e5d4902801661a3e9731f6d16309c1f17fd --- /dev/null +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_SEQUENTIALSCHEDULER_H_ +#define AIDGE_SEQUENTIALSCHEDULER_H_ + +#include <chrono> +#include <memory> +#include <set> +#include <string> +#include <vector> +#include <map> + +#include "aidge/scheduler/Scheduler.hpp" + +namespace Aidge { +/** + * Multi-threaded parallel scheduler with dynamic scheduling. +*/ +class SequentialScheduler : public Scheduler { +public: + enum SchedulingPolicy { + Default, + AsSoonAsPossible, + AsLateAsPossible + }; + + SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) + : Scheduler(graphView, upperNode), + mSchedulingPolicy(Default) + { + // ctor + }; + inline void setSchedulingPolicy(SchedulingPolicy policy) { + mSchedulingPolicy = policy; + } + ~SequentialScheduler() = default; + + /** + * @brief Run the provided Computational Graph with a batch of data + */ + virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); + +private: + SchedulingPolicy mSchedulingPolicy; +}; +} // namespace Aidge + +#endif /* AIDGE_SEQUENTIALSCHEDULER_H_ */ diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index a5bd260ec189ac998134b738ca1ae757f2a0038c..6e16eda0631c8e6bc1b342c215518a35672a5020 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -131,18 +131,15 @@ void declare_registrable(py::module& m, const std::string& class_name){ */ #ifdef PYBIND #define SET_IMPL_MACRO(T_Op, op, backend_name) \ - \ - if(Py_IsInitialized()) { \ - auto obj = py::cast(&(op)); \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ - } else { \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ - } + if(Py_IsInitialized()) { \ + auto obj = py::cast(&(op)); \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + } else { \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + } #else #define SET_IMPL_MACRO(T_Op, op, backend_name) \ - if (Registrar<T_Op>::exists(backend_name)) { \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ - } + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); #endif } diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 4f1b444f997831acefd4b1e24344f6395a46eadc..8e23e9e7cdc5c435f505a234003ebb794137e3ae 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -12,19 +12,30 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> #include "aidge/scheduler/Scheduler.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" +#include "aidge/scheduler/ParallelScheduler.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" namespace py = pybind11; namespace Aidge { void init_Scheduler(py::module& m){ - py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler") + py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) - .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>()) - .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) - .def("resetScheduling", &SequentialScheduler::resetScheduling) - .def("generate_scheduling", &SequentialScheduler::generateScheduling) - .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0) + .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name")) + .def("resetScheduling", &Scheduler::resetScheduling) + .def("generate_scheduling", &Scheduler::generateScheduling) + .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0) + ; + + py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler") + .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) + .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) + ; + + py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler") + .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) + .def("forward", &ParallelScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) ; } } diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1dd13fe2100122002d4ed068ada4851b1bfba463 --- /dev/null +++ b/src/scheduler/ParallelScheduler.cpp @@ -0,0 +1,200 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/scheduler/ParallelScheduler.hpp" +#include "aidge/scheduler/ThreadPool.hpp" + +#include <chrono> +#include <memory> +#include <set> +#include <string> + +#include <fmt/ranges.h> +#include <fmt/color.h> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/Memorize.hpp" +#include "aidge/operator/MetaOperator.hpp" + +void Aidge::ParallelScheduler::forward(bool forwardDims, std::vector<std::shared_ptr<Aidge::Tensor>> data) { + // Collect all data input of the graph (that are producers) + if (!data.empty()){ + connectInputs(data); + } + + // Forward dims (if allowed) + if (forwardDims) {mGraphView->forwardDims(); } + + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. one or + // several successive call to generateScheduling()), do not generate it twice + if (mStaticSchedule.empty()) { + this->generateScheduling(); + } + + const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + + // Sort static scheduling, the order will be the prefered threads scheduling + // order for non critical nodes + std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); + std::stable_sort(staticSchedule.begin(), staticSchedule.end(), + [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); + + // The thread pool has N threads running to process nodes. + // Thread pooling avoid the overhead of threads creation and deletion for + // each node execution. + ThreadPool pool; + + size_t latest = 0; + std::mutex schedulingMutex; + std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished; + + while (!staticSchedule.empty()) { + Log::debug("Step {}", latest); + + std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish; + + // Run all nodes that must be run at this step: latest (critical nodes) + for (size_t i = 0; i < staticSchedule.size(); ) { + auto runnable = staticSchedule[i]; + + if (runnable->late == latest) { + // Wait for potential preceding non-critical nodes to finish + while (true) { + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } + } + + // Add the critical node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + const auto tStart = std::chrono::high_resolution_clock::now(); + node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + finished = true; + { + std::unique_lock<std::mutex> lock(schedulingMutex); + mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); + } + }); + staticSchedule.erase(staticSchedule.begin() + i); + mustFinish.push_back(runnable); + + Log::debug(" run critical {}", namePtrTable.at(runnable->node)); + + // Ensure the following nodes cannot start earlier than next step + for (auto elt : runnable->earlierThan) { + if (elt->early < latest + 1) { + Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); + elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); + } + } + } + else if (runnable->early > latest + 1) { + // There cannot be more node that must be run at latest + 1 + // latest + 1 and not latest because early may have been updated + // for some elements to latest + 1 (above). + break; + } + else { + ++i; + } + } + + // If some threads are still available, run next early nodes + // These nodes are non-critical, meaning they can still be run at least + // in the next step + for (size_t i = 0; i < staticSchedule.size(); ) { + auto runnable = staticSchedule[i]; + if (!pool.busy() && runnable->early <= latest) { + // Check that potential preceding non-critical nodes are finished + bool ready = true; + for (auto elt : runnable->laterThan) { + ready = ready && finished.at(elt); + } + + if (ready) { + // All preceding nodes have finished, this node can be run. + // Add the node to the thread pool queue, to be run ASAP + finished[runnable] = false; + pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + const auto tStart = std::chrono::high_resolution_clock::now(); + node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + finished = true; + { + std::unique_lock<std::mutex> lock(schedulingMutex); + mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); + } + }); + staticSchedule.erase(staticSchedule.begin() + i); + + Log::debug(" run {}", namePtrTable.at(runnable->node)); + + // Ensure the following nodes cannot start earlier than next step + for (auto elt : runnable->earlierThan) { + if (elt->early < latest + 1) { + Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); + elt->early = latest + 1; + AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); + } + } + } + else { + // The node cannot be run yet, because preceding nodes are + // still running, move to the next one in schedule + ++i; + } + } + else { + // Thread pool is already full or no more node can be run at + // this step (latest) + break; + } + } + + // Wait for all nodes that must finish at latest to be finished + // By scheduling construction, no other node can be started before all + // nodes at latest step are finished + while (true) { + bool ready = true; + for (auto elt : mustFinish) { + ready = ready && finished.at(elt); + } + if (!ready) { + std::this_thread::yield(); + } + else { + break; + } + } + + ++latest; + } + + ++mStaticScheduleStep; + if (mStaticScheduleStep == mStaticSchedule.size()) { + mStaticScheduleStep = 0; + } +} diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 11715901c067cb85f8df9e85d82b94b0b5627745..639375902286a766cf084b61d822913677794e36 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -10,7 +10,6 @@ ********************************************************************************/ #include "aidge/scheduler/Scheduler.hpp" -#include "aidge/scheduler/ThreadPool.hpp" #include <chrono> #include <memory> @@ -28,31 +27,13 @@ #include "aidge/operator/Memorize.hpp" #include "aidge/operator/MetaOperator.hpp" -void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { - putchar('['); - int pos = static_cast<int>(barWidth * progress); - for (int i = 0; i < barWidth; ++i) { - if (i <= pos) - putchar('#'); - else - putchar(' '); - } - fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo); - fflush(stdout); -} - -void Aidge::SequentialScheduler::generateScheduling() { +void Aidge::Scheduler::generateScheduling() { auto schedule = generateBaseScheduling(); generateEarlyLateScheduling(schedule); mStaticSchedule.push_back(schedule); } -std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>> Aidge::SequentialScheduler::generateBaseScheduling() const { - // TODO: For loop on the list of node to run - // run sequencially every runnable consumers once - // TODO: handle memory allocation in scheduler - // TODO: optimize memory usage - +std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const { // 1) Setup initial consumers list: // It is the list of input nodes std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); @@ -302,7 +283,7 @@ std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement> return schedule; } -void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const { +void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const { size_t latest = 0; // Calculate early (logical) start for (size_t elt = 0; elt < schedule.size(); ++elt) { @@ -378,7 +359,7 @@ void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::sh } } -void Aidge::SequentialScheduler::resetScheduling() { +void Aidge::Scheduler::resetScheduling() { for (auto node : mGraphView->getNodes()) { node->getOperator()->resetConsummerProducer(); } @@ -391,7 +372,7 @@ void Aidge::SequentialScheduler::resetScheduling() { /** * This version is a simplified version without special handling of concatenation. */ -Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { +Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { MemoryManager memManager; for (size_t step = 0; step < mStaticSchedule.size(); ++step) { @@ -497,7 +478,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer return memManager; } -void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){ +void Aidge::Scheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){ // This version of connect inputs only connects tensor inputs in input data producers. auto inputNodes = mGraphView->getOrderedInputs(); @@ -510,49 +491,7 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge } } - -void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) { - - // Collect all data input of the graph (that are producers) - if (!data.empty()){ - connectInputs(data); - } - - // Forward dims (if allowed) - if (forwardDims) {mGraphView->forwardDims(); } - - // Generate scheduling *only if empty* - // If scheduling was already generated (in one or several steps, i.e. one or - // several successive call to generateScheduling()), do not generate it twice - if (mStaticSchedule.empty()) { - this->generateScheduling(); - } - - const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - - size_t cpt = 0; - for (const auto& runnable : getStaticScheduling(mStaticScheduleStep)) { - if (verbose) - fmt::print("run: {}\n", namePtrTable.at(runnable)); - else - drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, - (std::string("running ") + namePtrTable.at(runnable))); - const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); - cpt++; - } - if (!verbose) drawProgressBar(1.0, 50, " "); - fmt::print("\n"); - - ++mStaticScheduleStep; - if (mStaticScheduleStep == mStaticSchedule.size()) { - mStaticScheduleStep = 0; - } -} - -void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -582,7 +521,7 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa fmt::print(fp.get(), "\n"); } -void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -611,14 +550,14 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fmt::print(fp.get(), "\n"); } -std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const { +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(size_t step) const { const auto& staticSchedule = mStaticSchedule.at(step); std::vector<std::shared_ptr<Node>> schedule; std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; }); return schedule; } -std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( +std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers( const std::set<std::shared_ptr<Node>>& producers) const { std::set<std::shared_ptr<Node>> consumers; @@ -635,7 +574,7 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( return consumers; } -Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const { +Aidge::NbElts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const { const auto parent = node->inputs()[inputIdx]; if (parent.first) { @@ -676,7 +615,7 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared return 0; } -Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers( +Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers( const std::shared_ptr<Node>& node) const { PriorProducersConsumers prior; @@ -721,175 +660,3 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:: } return prior; } - -void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::vector<std::shared_ptr<Aidge::Tensor>> data) { - - // Collect all data input of the graph (that are producers) - if (!data.empty()){ - connectInputs(data); - } - - // Forward dims (if allowed) - if (forwardDims) {mGraphView->forwardDims(); } - - // Generate scheduling *only if empty* - // If scheduling was already generated (in one or several steps, i.e. one or - // several successive call to generateScheduling()), do not generate it twice - if (mStaticSchedule.empty()) { - this->generateScheduling(); - } - - const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - - // Sort static scheduling, the order will be the prefered threads scheduling - // order for non critical nodes - std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); - std::stable_sort(staticSchedule.begin(), staticSchedule.end(), - [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); - - // The thread pool has N threads running to process nodes. - // Thread pooling avoid the overhead of threads creation and deletion for - // each node execution. - ThreadPool pool; - - size_t latest = 0; - std::mutex schedulingMutex; - std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished; - - while (!staticSchedule.empty()) { - Log::debug("Step {}", latest); - - std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish; - - // Run all nodes that must be run at this step: latest (critical nodes) - for (size_t i = 0; i < staticSchedule.size(); ) { - auto runnable = staticSchedule[i]; - - if (runnable->late == latest) { - // Wait for potential preceding non-critical nodes to finish - while (true) { - bool ready = true; - for (auto elt : runnable->laterThan) { - ready = ready && finished.at(elt); - } - if (!ready) { - std::this_thread::yield(); - } - else { - break; - } - } - - // Add the critical node to the thread pool queue, to be run ASAP - finished[runnable] = false; - pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { - const auto tStart = std::chrono::high_resolution_clock::now(); - node->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - finished = true; - { - std::unique_lock<std::mutex> lock(schedulingMutex); - mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); - } - }); - staticSchedule.erase(staticSchedule.begin() + i); - mustFinish.push_back(runnable); - - Log::debug(" run critical {}", namePtrTable.at(runnable->node)); - - // Ensure the following nodes cannot start earlier than next step - for (auto elt : runnable->earlierThan) { - if (elt->early < latest + 1) { - Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); - elt->early = latest + 1; - AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); - } - } - } - else if (runnable->early > latest + 1) { - // There cannot be more node that must be run at latest + 1 - // latest + 1 and not latest because early may have been updated - // for some elements to latest + 1 (above). - break; - } - else { - ++i; - } - } - - // If some threads are still available, run next early nodes - // These nodes are non-critical, meaning they can still be run at least - // in the next step - for (size_t i = 0; i < staticSchedule.size(); ) { - auto runnable = staticSchedule[i]; - if (!pool.busy() && runnable->early <= latest) { - // Check that potential preceding non-critical nodes are finished - bool ready = true; - for (auto elt : runnable->laterThan) { - ready = ready && finished.at(elt); - } - - if (ready) { - // All preceding nodes have finished, this node can be run. - // Add the node to the thread pool queue, to be run ASAP - finished[runnable] = false; - pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { - const auto tStart = std::chrono::high_resolution_clock::now(); - node->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - finished = true; - { - std::unique_lock<std::mutex> lock(schedulingMutex); - mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd)); - } - }); - staticSchedule.erase(staticSchedule.begin() + i); - - Log::debug(" run {}", namePtrTable.at(runnable->node)); - - // Ensure the following nodes cannot start earlier than next step - for (auto elt : runnable->earlierThan) { - if (elt->early < latest + 1) { - Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1); - elt->early = latest + 1; - AIDGE_INTERNAL_ASSERT(elt->early <= elt->late); - } - } - } - else { - // The node cannot be run yet, because preceding nodes are - // still running, move to the next one in schedule - ++i; - } - } - else { - // Thread pool is already full or no more node can be run at - // this step (latest) - break; - } - } - - // Wait for all nodes that must finish at latest to be finished - // By scheduling construction, no other node can be started before all - // nodes at latest step are finished - while (true) { - bool ready = true; - for (auto elt : mustFinish) { - ready = ready && finished.at(elt); - } - if (!ready) { - std::this_thread::yield(); - } - else { - break; - } - } - - ++latest; - } - - ++mStaticScheduleStep; - if (mStaticScheduleStep == mStaticSchedule.size()) { - mStaticScheduleStep = 0; - } -} diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f39c6120708a772d233c31d07ebd11fbe8d3d1a2 --- /dev/null +++ b/src/scheduler/SequentialScheduler.cpp @@ -0,0 +1,73 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/scheduler/SequentialScheduler.hpp" + +#include <chrono> +#include <memory> +#include <set> +#include <string> + +#include <fmt/ranges.h> +#include <fmt/color.h> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/Types.h" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/Memorize.hpp" +#include "aidge/operator/MetaOperator.hpp" + +void Aidge::SequentialScheduler::forward(bool forwardDims, std::vector<std::shared_ptr<Aidge::Tensor>> data) { + // Collect all data input of the graph (that are producers) + if (!data.empty()){ + connectInputs(data); + } + + // Forward dims (if allowed) + if (forwardDims) {mGraphView->forwardDims(); } + + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. one or + // several successive call to generateScheduling()), do not generate it twice + if (mStaticSchedule.empty()) { + this->generateScheduling(); + } + + // Sort static scheduling according to the policy + std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); + + if (mSchedulingPolicy == AsSoonAsPossible) { + std::stable_sort(staticSchedule.begin(), staticSchedule.end(), + [](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); }); + } + else if (mSchedulingPolicy == AsLateAsPossible) { + std::stable_sort(staticSchedule.begin(), staticSchedule.end(), + [](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); }); + } + + const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + + for (const auto& runnable : staticSchedule) { + Log::debug("run: {}", namePtrTable.at(runnable->node)); + + const auto tStart = std::chrono::high_resolution_clock::now(); + runnable->node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd)); + } + + ++mStaticScheduleStep; + if (mStaticScheduleStep == mStaticSchedule.size()) { + mStaticScheduleStep = 0; + } +} diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 563fc19907e8800bbe6dec9e23242c6bb2c90c43..ab5fef1f6e861fe7dbd2d0095547b378209e74c0 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -25,7 +25,7 @@ #include "aidge/graph/OpArgs.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/scheduler/Scheduler.hpp" +#include "aidge/scheduler/SequentialScheduler.hpp" using namespace Aidge;