From cc2136e477bd890a2a6646434c090d02a7c2cfb3 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 5 Apr 2024 10:09:33 +0000 Subject: [PATCH] clean of scheduling files --- include/aidge/scheduler/MemoryManager.hpp | 14 +- include/aidge/scheduler/ParallelScheduler.hpp | 6 +- include/aidge/scheduler/Scheduler.hpp | 74 +++++---- .../aidge/scheduler/SequentialScheduler.hpp | 24 +-- include/aidge/scheduler/ThreadPool.hpp | 6 +- src/scheduler/Scheduler.cpp | 156 +++++++++--------- src/scheduler/SequentialScheduler.cpp | 4 +- unit_tests/scheduler/Test_Scheduler.cpp | 4 +- 8 files changed, 155 insertions(+), 133 deletions(-) diff --git a/include/aidge/scheduler/MemoryManager.hpp b/include/aidge/scheduler/MemoryManager.hpp index 9f718e8df..21d122b44 100644 --- a/include/aidge/scheduler/MemoryManager.hpp +++ b/include/aidge/scheduler/MemoryManager.hpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#ifndef AIDGE_MEMORY_MANAGER_H -#define AIDGE_MEMORY_MANAGER_H +#ifndef AIDGE_CORE_SCHEDULER_MEMORY_MANAGER_H +#define AIDGE_CORE_SCHEDULER_MEMORY_MANAGER_H #include <memory> #include <vector> @@ -75,12 +75,12 @@ public: count(count_) { assert(offset <= memSpace->size); - // The preceding assert should allow offset == memSpace->size (see + // The preceding assert should allow offset == memSpace->size (see // issue #63). This means immediate wrapping. // It appends if the final offset computed in reallocate() is at // the end of the previous memPlane and is also at the end of the // memSpace (in case for example of in-place memory op.). - // Instead of bringing the offset back to the beginning of the + // Instead of bringing the offset back to the beginning of the // memSpace, we stay attached to this offset in case the memSpace // grows when a new memPlane is added. @@ -128,7 +128,7 @@ public: // Limit is computed dynamically, as memSpace->size may increase after // the creation of this memory space. This is actually necessary to - // ensure that the memory wrapping works correctly, because when + // ensure that the memory wrapping works correctly, because when // computing the margin required for the wrapping, it is assumed that // the previous layer wrapping extends to the full memory space size. inline unsigned int getLimit() const { @@ -246,7 +246,7 @@ public: unsigned int stride = 0, unsigned int length = 1, unsigned int count = 1); - /// Generate a new MemoryPlane in an existing MemorySpace, associated to a + /// Generate a new MemoryPlane in an existing MemorySpace, associated to a /// Node unsigned int reallocate(std::shared_ptr<MemorySpace> memSpace, const std::shared_ptr<Node>& node, @@ -321,4 +321,4 @@ const char* const EnumStrings<Aidge::MemoryManager::OptimizeStrategy>::data[] "OptimizeMaxHoleMaxLifetimeFirst"}; } -#endif // AIDGE_MEMORY_MANAGER_H +#endif // AIDGE_CORE_SCHEDULER_MEMORY_MANAGER_H diff --git a/include/aidge/scheduler/ParallelScheduler.hpp b/include/aidge/scheduler/ParallelScheduler.hpp index d471c65ff..0b6f963d6 100644 --- a/include/aidge/scheduler/ParallelScheduler.hpp +++ b/include/aidge/scheduler/ParallelScheduler.hpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#ifndef AIDGE_PARALLELSCHEDULER_H_ -#define AIDGE_PARALLELSCHEDULER_H_ +#ifndef AIDGE_CORE_SCHEDULER_PARALLELSCHEDULER_H_ +#define AIDGE_CORE_SCHEDULER_PARALLELSCHEDULER_H_ #include <chrono> #include <memory> @@ -41,4 +41,4 @@ public: }; } // namespace Aidge -#endif /* AIDGE_PARALLELSCHEDULER_H_ */ +#endif /* AIDGE_CORE_SCHEDULER_PARALLELSCHEDULER_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 79eeefb2b..75ac90502 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -9,20 +9,20 @@ * ********************************************************************************/ -#ifndef AIDGE_SCHEDULER_H_ -#define AIDGE_SCHEDULER_H_ +#ifndef AIDGE_CORE_SCHEDULER_SCHEDULER_H_ +#define AIDGE_CORE_SCHEDULER_SCHEDULER_H_ +#include <cstddef> // std::size_t #include <chrono> +#include <map> #include <memory> #include <set> #include <string> #include <vector> -#include <map> - -#include "aidge/utils/Types.h" #include "aidge/data/Tensor.hpp" #include "aidge/scheduler/MemoryManager.hpp" +#include "aidge/utils/Types.h" namespace Aidge { class Node; @@ -33,17 +33,20 @@ protected: struct StaticSchedulingElement { StaticSchedulingElement( std::shared_ptr<Node> node_, - size_t early_ = static_cast<size_t>(-1), - size_t late_ = static_cast<size_t>(-1)) + std::size_t early_ = static_cast<std::size_t>(-1), + std::size_t late_ = static_cast<std::size_t>(-1)) : node(node_), early(early_), late(late_) {} std::shared_ptr<Node> node; - size_t early; - size_t late; + std::size_t early; + std::size_t late; std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; }; + /** + * @brief Node with its start/end execution time stored for later display. + */ struct SchedulingElement { SchedulingElement( std::shared_ptr<Node> node_, @@ -69,10 +72,22 @@ public: { // ctor }; - virtual ~Scheduler() = default; + virtual ~Scheduler() noexcept = default; + +public: /** - * Generate full static scheduling of the GraphView. + * @brief Return a vector of Node ordered by the order they are called by the scheduler. + * @return std::vector<std::shared_ptr<Node>> + */ + std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0) const; + + inline std::shared_ptr<GraphView> graphView() const noexcept { + return mGraphView; + } + + /** + * @brief Generate full static scheduling of the GraphView. * For each node, an earliest and latest possible execution logical step * is specified. Nodes that may be scheduled at the same logical step have * no data dependency and can be run in parallel. @@ -110,18 +125,21 @@ public: */ void saveSchedulingDiagram(const std::string& fileName) const; + +protected: /** - * @brief Return a vector of Node ordered by the order they are called by the scheduler - * @return std::vector<std::shared_ptr<Node>> + * @brief Getter for the set of children Nodes of the given input Nodes. + * @param producers Set of Nodes for which we want to obtain the set of children Nodes. + * @return std::set<std::shared_ptr<Node>> Children Nodes. */ - std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const; - inline std::shared_ptr<GraphView> getGraphView() const noexcept { - return mGraphView; - } + std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; + + Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; + + PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const; -protected: /** - * Generate an initial base scheduling for the GraphView. + * @brief Generate an initial base scheduling for the GraphView. * The scheduling is entirely sequential and garanteed to be valid w.r.t. * each node producer-consumer model. */ @@ -129,21 +147,15 @@ protected: /** * Fill-in early and late scheduling step from initial base scheduling. - * For each node, specifies the earliest and latest possible execution + * For each node, specifies the earliest and latest possible execution * logical step. */ void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const; - /** - * @brief Set of layers receiving an input from currently processing layers - * - * @param producers Set of layers ready to run. - * @return std::set<std::shared_ptr<Node>> - */ - std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; - Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; - PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const; +private: + void summarizeConsumerState(const std::shared_ptr<Node>& consumer, const std::string& nodeName) const; +protected: /** @brief Shared ptr to the scheduled graph view */ std::shared_ptr<GraphView> mGraphView; /** @brief Shared ptr to the upper node containing the graph view */ @@ -152,9 +164,9 @@ protected: std::vector<SchedulingElement> mScheduling; /** @brief List of nodes ordered by their */ std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule; - size_t mStaticScheduleStep = 0; + std::size_t mStaticScheduleStep = 0; mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache; }; } // namespace Aidge -#endif /* AIDGE_SCHEDULER_H_ */ +#endif /* AIDGE_CORE_SCHEDULER_SCHEDULER_H_ */ diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index be0a4a991..9cf0c2c18 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -9,16 +9,15 @@ * ********************************************************************************/ -#ifndef AIDGE_SEQUENTIALSCHEDULER_H_ -#define AIDGE_SEQUENTIALSCHEDULER_H_ +#ifndef AIDGE_CORE_SCHEDULER_SEQUENTIALSCHEDULER_H_ +#define AIDGE_CORE_SCHEDULER_SEQUENTIALSCHEDULER_H_ -#include <chrono> #include <memory> -#include <set> -#include <string> #include <vector> -#include <map> +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" #include "aidge/scheduler/Scheduler.hpp" namespace Aidge { @@ -27,23 +26,26 @@ namespace Aidge { */ class SequentialScheduler : public Scheduler { public: - enum SchedulingPolicy { + enum class SchedulingPolicy { Default, AsSoonAsPossible, AsLateAsPossible }; +public: SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) : Scheduler(graphView, upperNode), - mSchedulingPolicy(Default) + mSchedulingPolicy(SchedulingPolicy::Default) { // ctor }; + + ~SequentialScheduler() = default; + +public: inline void setSchedulingPolicy(SchedulingPolicy policy) { mSchedulingPolicy = policy; } - ~SequentialScheduler() = default; - /** * @brief Run the provided Computational Graph with a batch of data */ @@ -59,4 +61,4 @@ private: }; } // namespace Aidge -#endif /* AIDGE_SEQUENTIALSCHEDULER_H_ */ +#endif /* AIDGE_CORE_SCHEDULER_SEQUENTIALSCHEDULER_H_ */ diff --git a/include/aidge/scheduler/ThreadPool.hpp b/include/aidge/scheduler/ThreadPool.hpp index 5f2d9192d..e016ad4f3 100644 --- a/include/aidge/scheduler/ThreadPool.hpp +++ b/include/aidge/scheduler/ThreadPool.hpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#ifndef AIDGE_THREADPOOL_H_ -#define AIDGE_THREADPOOL_H_ +#ifndef AIDGE_CORE_SCHEDULER_THREADPOOL_H_ +#define AIDGE_CORE_SCHEDULER_THREADPOOL_H_ #include <thread> #include <mutex> @@ -39,4 +39,4 @@ private: }; } // namespace Aidge -#endif /* AIDGE_THREADPOOL_H_ */ +#endif /* AIDGE_CORE_SCHEDULER_THREADPOOL_H_ */ diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index b3b2d5e5b..ecf000ef3 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -11,21 +11,29 @@ #include "aidge/scheduler/Scheduler.hpp" +#include <algorithm> // std::find, std::find_if, std::max, std::min, std::replace, std::transform +#include <cassert> #include <chrono> +#include <cstddef> // std::size_t +#include <cstdio> // std::fclose, std::fopen +#include <iterator> // std::back_inserter, std::distance +#include <map> #include <memory> #include <set> #include <string> +#include <vector> -#include <fmt/ranges.h> +#include <fmt/core.h> #include <fmt/color.h> +#include <fmt/ranges.h> #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" -#include "aidge/operator/OperatorTensor.hpp" -#include "aidge/utils/Types.h" -#include "aidge/operator/Producer.hpp" #include "aidge/operator/Memorize.hpp" #include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Types.h" void Aidge::Scheduler::generateScheduling() { auto schedule = generateBaseScheduling(); @@ -34,29 +42,37 @@ void Aidge::Scheduler::generateScheduling() { } std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const { - // 1) Setup initial consumers list: - // It is the list of input nodes - std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); - // Plus the list of nodes inside the graph connected to an inner producer - std::set<std::shared_ptr<Node>> producers; - for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { - if (nodePtr->type() == Producer_Op::Type) { - producers.insert(nodePtr); - } - } - const auto producersConsumers = getConsumers(producers); - consumers.insert(producersConsumers.begin(), producersConsumers.end()); + // 0) setup useful variables + // map associating each node with string "name (type#rank)" const std::map<std::shared_ptr<Node>, std::string> namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - // Still consumers are consumers that were run by can still consume data. + // consumers that were run by but can still consume data. // They must be run AFTER the remaining consumer to ensure a non-greedy // producers-consumers model! std::set<std::shared_ptr<Node>> stillConsumers; std::vector<std::shared_ptr<StaticSchedulingElement>> schedule; + + // 1) Initialize consumers list: + // 1.1) List of the GraphView's input nodes + std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); + + // 1.2) List of nodes inside the GraphView connected to an inner Producer + std::set<std::shared_ptr<Node>> producers; + for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { + if (nodePtr->type() == Producer_Op::Type) { + for (const auto& child : nodePtr->getChildren()) { + // Do not schedule childs outside current graph! + if (mGraphView->inView(child)) { + consumers.insert(child); + } + } + } + } + do { // 2) From the current consumers list, check if any prior consumer node // is needed. A prior will generally be required for any node consuming @@ -69,8 +85,8 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S // in the new priorConsumers list. The initial consumer will become // again a consumer later, by construction. Log::debug("List of consumers with their priors:"); - std::set<std::shared_ptr<Node>> requiredProducers; - std::set<std::shared_ptr<Node>> priorConsumers; + std::set<std::shared_ptr<Node>> requiredProducers; // Priors of type Producer + std::set<std::shared_ptr<Node>> priorConsumers; // Priors of other type mPriorCache.clear(); for (const auto& consumer : consumers) { @@ -119,23 +135,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S std::set<std::shared_ptr<Node>> runnableConsumers; Log::debug("Updated list of consumers:"); for (const auto& consumer : consumers) { - Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); - - std::string crLog = "\t\tC/R:\t"; - for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), - consumer->getOperator()->getNbRequiredData(inId)); - } - crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), - consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); - Log::debug("{}", crLog); - - std::string pLog = "\t\tP:\t"; - for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); - } - pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); - Log::debug("{}", pLog); + summarizeConsumerState(consumer, namePtrTable.at(consumer)); // debug print bool isRunnable = true; for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { @@ -184,24 +184,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S // 7) Update consumers list Log::debug("Updating producer and consumer lists..."); for (const auto& consumer : runnableConsumers) { - Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); - - std::string crLog = "\t\tC/R:\t"; - for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), - consumer->getOperator()->getNbRequiredData(inId)); - } - crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), - consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); - Log::debug("{}", crLog); - - std::string pLog = "\t\tP:\t"; - for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); - } - pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); - Log::debug("{}", pLog); - + summarizeConsumerState(consumer, namePtrTable.at(consumer)); // debug print // 7.1) If the current consumer has still data to consume, it will // be put back in the consumers list once the remaining consumers // have been exhausted. @@ -297,16 +280,37 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S return schedule; } + +void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>& consumer, const std::string& nodeName) const { + Log::debug("\t- consumer: {}", fmt::styled(nodeName, fg(fmt::color::orange))); + std::string crLog = "\t\tC/R:\t"; + for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { + crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + consumer->getOperator()->getNbRequiredData(inId)); + } + crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); + Log::debug("{}", crLog); + + std::string pLog = "\t\tP:\t"; + for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { + pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + } + pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + Log::debug("{}", pLog); +} + + void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const { - size_t latest = 0; + std::size_t latest = 0; // Calculate early (logical) start - for (size_t elt = 0; elt < schedule.size(); ++elt) { + for (std::size_t elt = 0; elt < schedule.size(); ++elt) { const auto node = schedule[elt]->node; const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(), [node](const auto& v) { return (v->node == node); }); // Node can be run the earliest just after its childs were run the last time! - size_t early = 0; + std::size_t early = 0; if (itNode != schedule.rend()) { for (const auto& child : node->getChildren()) { // Find child node next scheduled position @@ -314,7 +318,7 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<S [child](const auto& v) { return (v->node == child); }); AIDGE_INTERNAL_ASSERT(it != schedule.rend()); - const size_t step = std::distance(schedule.begin(), it.base()) - 1; + const std::size_t step = std::distance(schedule.begin(), it.base()) - 1; early = std::max(early, schedule[step]->early + 1); schedule[step]->earlierThan.push_back(schedule[elt]); } @@ -326,7 +330,7 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<S const auto it = std::find_if(schedule.rend() - elt, schedule.rend(), [parent](const auto& v) { return (v->node == parent); }); if (it != schedule.rend()) { - const size_t step = std::distance(schedule.begin(), it.base()) - 1; + const std::size_t step = std::distance(schedule.begin(), it.base()) - 1; early = std::max(early, schedule[step]->early + 1); schedule[step]->earlierThan.push_back(schedule[elt]); } @@ -337,13 +341,13 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<S } // Calculate late (logical) start - for (size_t elt = schedule.size(); elt-- != 0; ) { + for (std::size_t elt = schedule.size(); elt-- != 0; ) { const auto node = schedule[elt]->node; const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(), [node](const auto& v) { return (v->node == node); }); // Node can be run the latest just before its parents are run the next time! - size_t late = latest; + std::size_t late = latest; if (itNode != schedule.end()) { for (const auto& parent : node->getParents()) { // Find child node next scheduled position @@ -351,7 +355,7 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<S [parent](const auto& v) { return (v->node == parent); }); AIDGE_INTERNAL_ASSERT(it != schedule.end()); - const size_t step = std::distance(schedule.begin(), it); + const std::size_t step = std::distance(schedule.begin(), it); late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } @@ -363,7 +367,7 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<S const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), [child](const auto& v) { return (v->node == child); }); if (it != schedule.end()) { - const size_t step = std::distance(schedule.begin(), it); + const std::size_t step = std::distance(schedule.begin(), it); late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } @@ -389,7 +393,7 @@ void Aidge::Scheduler::resetScheduling() { Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { MemoryManager memManager; - for (size_t step = 0; step < mStaticSchedule.size(); ++step) { + for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { for (const auto& node : getStaticScheduling(step)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); @@ -412,10 +416,10 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr node->name(), node->type()); // By default, specifies a fully monolithic memory block - size_t size = requiredSize.data; - size_t stride = 0; - size_t length = 1; - size_t count = 1; + std::size_t size = requiredSize.data; + std::size_t stride = 0; + std::size_t length = 1; + std::size_t count = 1; if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) { // If it is possible, assume a NCHW layout @@ -428,8 +432,8 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr // Check if wrap around buffer is possible for this node // (re-using previous node outputs memory for this node outputs). // => only if this node is the only child of its parent(s) - size_t wrapAroundSize = 0; - size_t wrapAroundExtra = 0; + std::size_t wrapAroundSize = 0; + std::size_t wrapAroundExtra = 0; wrapAroundMemPlane.push_back(nullptr); // Select the best parent among all allocable nodes for @@ -575,7 +579,7 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) fmt::print(fp.get(), "\n"); } -std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(size_t step) const { +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const { const auto& staticSchedule = mStaticSchedule.at(step); std::vector<std::shared_ptr<Node>> schedule; std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; }); @@ -659,24 +663,26 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon if ((node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) > parent.first->getOperator()->getNbProducedData(parent.second)) { + // the node needs more data than the current parent has provided yet if (!mGraphView->inView(parent.first)) { // Do not schedule prior outside the current graph! - return PriorProducersConsumers(); + // return PriorProducersConsumers(); // not scheduled + prior.priorConsumers.insert(node); } - if (parent.first->type() == Producer_Op::Type) { + else if (parent.first->type() == Producer_Op::Type) { prior.requiredProducers.insert(parent.first); prior.priorConsumers.insert(node); } else if (parent.first->type() == Memorize_Op::Type) { // Break cycles - return PriorProducersConsumers(); + return PriorProducersConsumers(); // not scheduled } else { const auto& parentPrior = getPriorProducersConsumers(parent.first); if (!parentPrior.isPrior) { - return PriorProducersConsumers(); + return PriorProducersConsumers(); // not scheduled } else { prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend()); diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 1454cb2ba..801f46ffb 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -47,11 +47,11 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, std::vector<std::shar // Sort static scheduling according to the policy std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); - if (mSchedulingPolicy == AsSoonAsPossible) { + if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) { std::stable_sort(staticSchedule.begin(), staticSchedule.end(), [](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); }); } - else if (mSchedulingPolicy == AsLateAsPossible) { + else if (mSchedulingPolicy == SchedulingPolicy::AsLateAsPossible) { std::stable_sort(staticSchedule.begin(), staticSchedule.end(), [](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); }); } diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 7eb0290d9..e2c1a8fcb 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -27,7 +27,7 @@ #include "aidge/operator/Producer.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" -using namespace Aidge; +namespace Aidge { TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { const size_t nbTests = 10; @@ -132,3 +132,5 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { } fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); } + +} // namespace Aidge -- GitLab