diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 792d73693be0780f2e938d828b0f29889216631b..981920ea1b010edd01865bbd702a601a941941d5 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -28,8 +28,29 @@ namespace Aidge { class Node; class GraphView; + +/** + * @class Scheduler + * @brief Generate and manage the execution schedule order of nodes in a graph. + * It provides functionality for static scheduling, memory + * management, and visualization of the scheduling process. + * + * Key features: + * - Static scheduling generation with early and late execution times + * - Memory layout generation for scheduled nodes + * - Input tensor connection to graph nodes + * - Scheduling visualization through diagram generation + * + * @see GraphView + * @see Node + * @see MemoryManager + */ class Scheduler { protected: + /** + * @struct StaticSchedulingElement + * @brief Represents a node in the static schedule. + */ struct StaticSchedulingElement { StaticSchedulingElement( std::shared_ptr<Node> node_, @@ -37,15 +58,17 @@ protected: std::size_t late_ = static_cast<std::size_t>(-1)) : node(node_), early(early_), late(late_) {} - std::shared_ptr<Node> node; - std::size_t early; - std::size_t late; - std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; - std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; + std::shared_ptr<Node> node; /** Scheduled `Node` */ + std::size_t early; /** Earliest possible execution time */ + std::size_t late; /** Latest possible execution time */ + std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan; /** Nodes that must be executed earlier */ + std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan; /** Nodes that must be executed later */ }; /** - * @brief Node with its start/end execution time stored for later display. + * @struct SchedulingElement + * @brief Represent a `Node` with its actual execution times. + * @details Start and end times are stored for later display. */ struct SchedulingElement { SchedulingElement( @@ -54,21 +77,32 @@ protected: std::chrono::time_point<std::chrono::high_resolution_clock> end_) : node(node_), start(start_), end(end_) {} ~SchedulingElement() noexcept = default; - std::shared_ptr<Node> node; - std::chrono::time_point<std::chrono::high_resolution_clock> start; - std::chrono::time_point<std::chrono::high_resolution_clock> end; + std::shared_ptr<Node> node; /** Executed `Node` */ + std::chrono::time_point<std::chrono::high_resolution_clock> start; /** Actual start time of execution */ + std::chrono::time_point<std::chrono::high_resolution_clock> end; /** Actual end time of execution */ }; public: + /** + * @struct PriorProducersConsumers + * @brief Manages producer-consumer relationships for nodes. + */ struct PriorProducersConsumers { PriorProducersConsumers(); PriorProducersConsumers(const PriorProducersConsumers&); ~PriorProducersConsumers() noexcept; - bool isPrior = false; - std::set<std::shared_ptr<Aidge::Node>> requiredProducers; - std::set<std::shared_ptr<Aidge::Node>> priorConsumers; + bool isPrior = false; /** Indicates if this Node is a prior to another Node */ + std::set<std::shared_ptr<Aidge::Node>> requiredProducers; /** Set of required producer nodes */ + std::set<std::shared_ptr<Aidge::Node>> priorConsumers; /** Set of required prior consumer nodes */ }; public: + Scheduler() = delete; + + /** + * @brief Constructor for the Scheduler class. + * @param graphView Shared pointer to the GraphView to be scheduled. + * @param upperNode Shared pointer to the upper node of the GraphView (optional). + */ Scheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) : mGraphView(graphView), mUpperNode(upperNode) @@ -80,11 +114,16 @@ public: public: /** - * @brief Return a vector of Node ordered by the order they are called by the scheduler. - * @return std::vector<std::shared_ptr<Node>> + * @brief Get the static scheduling order of nodes. + * @param step The step of the static schedule to retrieve (default is 0). + * @return Vector of shared pointers to Nodes in their scheduled order. */ std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0) const; + /** + * @brief Get the GraphView associated with this Scheduler. + * @return Shared pointer to the GraphView. + */ inline std::shared_ptr<GraphView> graphView() const noexcept { return mGraphView; } @@ -110,20 +149,23 @@ public: MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const; /** - * @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph. + * @brief Connect input tensors to the data input of the GraphView. + * In case of multiple data input tensors, they are mapped to producers in + * the order given by the graph. * * @param data data input tensors */ void connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data); /** - * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes. - * @param fileName Name of the generated file. + * @brief Save the static scheduling diagram, with early and late relative + * order of execution for the nodes, to a file in Mermaid format. + * @param fileName Name of the file to save the diagram (without extension). */ void saveStaticSchedulingDiagram(const std::string& fileName) const; /** - * @brief Save in a Markdown file the order of layers execution. + * @brief Save in a Mermaid file the order of layers execution. * @param fileName Name of the generated file. */ void saveSchedulingDiagram(const std::string& fileName) const; @@ -139,29 +181,48 @@ protected: Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; + /** + * @brief Get the prior producers and consumers for a node. + * @param node Shared pointer to the Node. + * @return PriorProducersConsumers object containing prior information. + */ PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const; /** * @brief Generate an initial base scheduling for the GraphView. * The scheduling is entirely sequential and garanteed to be valid w.r.t. * each node producer-consumer model. + * @return Vector of shared pointers to `StaticSchedulingElement` representing the base schedule. */ std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const; /** - * Fill-in early and late scheduling step from initial base scheduling. - * For each node, specifies the earliest and latest possible execution - * logical step. - */ + * @brief Calculates early and late execution times for each node in an initial base scheduling. + * + * This method performs two passes over the schedule: + * 1. Forward pass: Calculates the earliest possible execution time for each node + * 2. Backward pass: Calculates the latest possible execution time for each node + * + * It also establishes 'earlierThan' and 'laterThan' relationships between nodes. + * + * @param schedule Vector of shared pointers to StaticSchedulingElements to be processed + */ void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const; private: + /** + * @brief Summarize the consumer state of a node for debugging purposes. + * @param consumer Shared pointer to the consumer Node. + * @param nodeName Name of the node. + * @details Provide the amount of data consumed and required for each input + * and the amount of data produced for each output. + */ void summarizeConsumerState(const std::shared_ptr<Node>& consumer, const std::string& nodeName) const; protected: - /** @brief Shared ptr to the scheduled graph view */ + /** @brief Shared pointer to the scheduled GraphView */ std::shared_ptr<GraphView> mGraphView; - /** @brief Shared ptr to the upper node containing the graph view */ + /** @brief Weak pointer to the upper node containing the graph view */ std::weak_ptr<Node> mUpperNode; /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */ std::vector<SchedulingElement> mScheduling;