Skip to content
Snippets Groups Projects
Commit 577e654a authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix - v0.3.1

Merge branch 'hotfix' into 'main'

See merge request !229
parents 7f2701ff 751191b1
No related branches found
No related tags found
3 merge requests!280Fix Scheduler::StaticSchedulingElement shared_ptr circular reference,!248Draft: [Add] MulPTQ and ScalingMeta MetaOperators,!229Version 0.3.1
Pipeline #59492 passed
......@@ -28,8 +28,29 @@ namespace Aidge {
class Node;
class GraphView;
/**
* @class Scheduler
* @brief Generate and manage the execution schedule order of nodes in a graph.
* It provides functionality for static scheduling, memory
* management, and visualization of the scheduling process.
*
* Key features:
* - Static scheduling generation with early and late execution times
* - Memory layout generation for scheduled nodes
* - Input tensor connection to graph nodes
* - Scheduling visualization through diagram generation
*
* @see GraphView
* @see Node
* @see MemoryManager
*/
class Scheduler {
protected:
/**
* @struct StaticSchedulingElement
* @brief Represents a node in the static schedule.
*/
struct StaticSchedulingElement {
StaticSchedulingElement(
std::shared_ptr<Node> node_,
......@@ -37,15 +58,17 @@ protected:
std::size_t late_ = static_cast<std::size_t>(-1))
: node(node_), early(early_), late(late_) {}
std::shared_ptr<Node> node;
std::size_t early;
std::size_t late;
std::vector<std::shared_ptr<StaticSchedulingElement>> earlierThan;
std::vector<std::shared_ptr<StaticSchedulingElement>> laterThan;
std::shared_ptr<Node> node; /** Scheduled `Node` */
std::size_t early; /** Earliest possible execution time */
std::size_t late; /** Latest possible execution time */
std::vector<StaticSchedulingElement*> earlierThan; /** Nodes that must be executed earlier */
std::vector<StaticSchedulingElement*> laterThan; /** Nodes that must be executed later */
};
/**
* @brief Node with its start/end execution time stored for later display.
* @struct SchedulingElement
* @brief Represent a `Node` with its actual execution times.
* @details Start and end times are stored for later display.
*/
struct SchedulingElement {
SchedulingElement(
......@@ -54,21 +77,32 @@ protected:
std::chrono::time_point<std::chrono::high_resolution_clock> end_)
: node(node_), start(start_), end(end_) {}
~SchedulingElement() noexcept = default;
std::shared_ptr<Node> node;
std::chrono::time_point<std::chrono::high_resolution_clock> start;
std::chrono::time_point<std::chrono::high_resolution_clock> end;
std::shared_ptr<Node> node; /** Executed `Node` */
std::chrono::time_point<std::chrono::high_resolution_clock> start; /** Actual start time of execution */
std::chrono::time_point<std::chrono::high_resolution_clock> end; /** Actual end time of execution */
};
public:
/**
* @struct PriorProducersConsumers
* @brief Manages producer-consumer relationships for nodes.
*/
struct PriorProducersConsumers {
PriorProducersConsumers();
PriorProducersConsumers(const PriorProducersConsumers&);
~PriorProducersConsumers() noexcept;
bool isPrior = false;
std::set<std::shared_ptr<Aidge::Node>> requiredProducers;
std::set<std::shared_ptr<Aidge::Node>> priorConsumers;
bool isPrior = false; /** Indicates if this Node is a prior to another Node */
std::set<std::shared_ptr<Aidge::Node>> requiredProducers; /** Set of required producer nodes */
std::set<std::shared_ptr<Aidge::Node>> priorConsumers; /** Set of required prior consumer nodes */
};
public:
Scheduler() = delete;
/**
* @brief Constructor for the Scheduler class.
* @param graphView Shared pointer to the GraphView to be scheduled.
* @param upperNode Shared pointer to the upper node of the GraphView (optional).
*/
Scheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: mGraphView(graphView),
mUpperNode(upperNode)
......@@ -76,15 +110,20 @@ public:
// ctor
};
virtual ~Scheduler() noexcept;
virtual ~Scheduler();
public:
/**
* @brief Return a vector of Node ordered by the order they are called by the scheduler.
* @return std::vector<std::shared_ptr<Node>>
* @brief Get the static scheduling order of nodes.
* @param step The step of the static schedule to retrieve (default is 0).
* @return Vector of shared pointers to Nodes in their scheduled order.
*/
std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0) const;
/**
* @brief Get the GraphView associated with this Scheduler.
* @return Shared pointer to the GraphView.
*/
inline std::shared_ptr<GraphView> graphView() const noexcept {
return mGraphView;
}
......@@ -110,20 +149,23 @@ public:
MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const;
/**
* @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph.
* @brief Connect input tensors to the data input of the GraphView.
* In case of multiple data input tensors, they are mapped to producers in
* the order given by the graph.
*
* @param data data input tensors
*/
void connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data);
/**
* @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
* @param fileName Name of the generated file.
* @brief Save the static scheduling diagram, with early and late relative
* order of execution for the nodes, to a file in Mermaid format.
* @param fileName Name of the file to save the diagram (without extension).
*/
void saveStaticSchedulingDiagram(const std::string& fileName) const;
/**
* @brief Save in a Markdown file the order of layers execution.
* @brief Save in a Mermaid file the order of layers execution.
* @param fileName Name of the generated file.
*/
void saveSchedulingDiagram(const std::string& fileName) const;
......@@ -139,34 +181,53 @@ protected:
Elts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
/**
* @brief Get the prior producers and consumers for a node.
* @param node Shared pointer to the Node.
* @return PriorProducersConsumers object containing prior information.
*/
PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
/**
* @brief Generate an initial base scheduling for the GraphView.
* The scheduling is entirely sequential and garanteed to be valid w.r.t.
* each node producer-consumer model.
* @return Vector of pointers to `StaticSchedulingElement` representing the base schedule.
*/
std::vector<std::shared_ptr<StaticSchedulingElement>> generateBaseScheduling() const;
std::vector<StaticSchedulingElement*> generateBaseScheduling() const;
/**
* Fill-in early and late scheduling step from initial base scheduling.
* For each node, specifies the earliest and latest possible execution
* logical step.
*/
void generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const;
* @brief Calculates early and late execution times for each node in an initial base scheduling.
*
* This method performs two passes over the schedule:
* 1. Forward pass: Calculates the earliest possible execution time for each node
* 2. Backward pass: Calculates the latest possible execution time for each node
*
* It also establishes 'earlierThan' and 'laterThan' relationships between nodes.
*
* @param schedule Vector of shared pointers to StaticSchedulingElements to be processed
*/
void generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const;
private:
/**
* @brief Summarize the consumer state of a node for debugging purposes.
* @param consumer Shared pointer to the consumer Node.
* @param nodeName Name of the node.
* @details Provide the amount of data consumed and required for each input
* and the amount of data produced for each output.
*/
void summarizeConsumerState(const std::shared_ptr<Node>& consumer, const std::string& nodeName) const;
protected:
/** @brief Shared ptr to the scheduled graph view */
/** @brief Shared pointer to the scheduled GraphView */
std::shared_ptr<GraphView> mGraphView;
/** @brief Shared ptr to the upper node containing the graph view */
/** @brief Weak pointer to the upper node containing the graph view */
std::weak_ptr<Node> mUpperNode;
/** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
std::vector<SchedulingElement> mScheduling;
/** @brief List of nodes ordered by their */
std::vector<std::vector<std::shared_ptr<StaticSchedulingElement>>> mStaticSchedule;
std::vector<std::vector<StaticSchedulingElement*>> mStaticSchedule;
std::size_t mStaticScheduleStep = 0;
mutable std::map<std::shared_ptr<Node>, PriorProducersConsumers> mPriorCache;
};
......
......@@ -48,7 +48,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
// Sort static scheduling, the order will be the prefered threads scheduling
// order for non critical nodes
std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
......@@ -59,12 +59,12 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
size_t latest = 0;
std::mutex schedulingMutex;
std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;
std::map<StaticSchedulingElement*, std::atomic<bool>> finished;
while (!staticSchedule.empty()) {
Log::debug("Step {}", latest);
std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;
std::vector<StaticSchedulingElement*> mustFinish;
// Run all nodes that must be run at this step: latest (critical nodes)
for (size_t i = 0; i < staticSchedule.size(); ) {
......@@ -188,7 +188,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
}
// Wait for all nodes that must finish at latest to be finished
// By scheduling construction, no other node can be started before all
// By scheduling construction, no other node can be started before all
// nodes at latest step are finished
while (true) {
bool ready = true;
......
......@@ -37,7 +37,14 @@
#include "aidge/utils/Types.h"
Aidge::Scheduler::~Scheduler() noexcept = default;
Aidge::Scheduler::~Scheduler() {
for (auto& staticScheduleVec : mStaticSchedule) {
for (auto& staticScheduleElt : staticScheduleVec) {
delete staticScheduleElt;
}
staticScheduleVec.clear();
}
}
Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers() = default;
Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers(const PriorProducersConsumers&) = default;
Aidge::Scheduler::PriorProducersConsumers::~PriorProducersConsumers() noexcept = default;
......@@ -48,7 +55,7 @@ void Aidge::Scheduler::generateScheduling() {
mStaticSchedule.push_back(schedule);
}
std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const {
std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
// 0) setup useful variables
// map associating each node with string "name (type#rank)"
......@@ -60,7 +67,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
// producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers;
std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
std::vector<StaticSchedulingElement*> schedule;
// 1) Initialize consumers list:
......@@ -131,7 +138,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
// Producers are special nodes that generate data on demand.
for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
schedule.push_back(new StaticSchedulingElement(requiredProducer));
}
// 5) Find runnable consumers.
......@@ -185,7 +192,7 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S
for (const auto& runnable : runnableConsumers) {
Log::debug("Runnable: {}", namePtrTable.at(runnable));
runnable->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
schedule.push_back(new StaticSchedulingElement(runnable));
}
// 7) Update consumers list
......@@ -317,7 +324,7 @@ void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>
}
void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const {
std::size_t latest = 0;
// Calculate early (logical) start
for (std::size_t elt = 0; elt < schedule.size(); ++elt) {
......@@ -397,15 +404,20 @@ void Aidge::Scheduler::resetScheduling() {
for (auto node : mGraphView->getNodes()) {
node->getOperator()->resetConsummerProducer();
}
for (auto& staticScheduleVec : mStaticSchedule) {
for (auto& staticScheduleElt : staticScheduleVec) {
delete staticScheduleElt;
}
staticScheduleVec.clear();
}
mStaticSchedule.clear();
mStaticScheduleStep = 0;
mScheduling.clear();
}
/**
* This version is a simplified version without special handling of concatenation.
*/
* @warning This version is a simplified version without special handling of concatenation.
*/
Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
MemoryManager memManager;
......@@ -676,8 +688,8 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>&
return Elts_t::NoneElts();
}
Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const
Aidge::Scheduler::PriorProducersConsumers
Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) const
{
const auto priorCache = mPriorCache.find(node);
if (priorCache != mPriorCache.end()) {
......@@ -714,6 +726,7 @@ Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersCon
const auto& parentPrior = getPriorProducersConsumers(parent.first);
if (!parentPrior.isPrior) {
// only happens in case of cyclic graphs
return PriorProducersConsumers(); // not scheduled
}
else {
......
......@@ -45,7 +45,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
}
// Sort static scheduling according to the policy
std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) {
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
......
......@@ -78,6 +78,7 @@ TEST_CASE("[core/operator] ConstantOfShape_Op(forwardDims)",
for (DimSize_t i = 0; i < op->getOutput(0)->nbDims(); ++i) {
CHECK(array_in[i] == op->getOutput(0)->dims().at(i));
}
delete[] array_in;
}
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment