Commit 10a9e375 authored by Olivier BICHLER

Small refactoring of Scheduler

parent 1eee61f4
2 merge requests: !105 version 0.2.0, !94 Improved scheduling
Pipeline #40847 failed
@@ -33,8 +33,8 @@ private:
     struct StaticSchedulingElement {
         StaticSchedulingElement(
             std::shared_ptr<Node> node_,
-            size_t early_,
-            size_t late_)
+            size_t early_ = static_cast<size_t>(-1),
+            size_t late_ = static_cast<size_t>(-1))
             : node(node_), early(early_), late(late_) {}
 
         std::shared_ptr<Node> node;
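The only functional change in this hunk is that early_ and late_ now default to static_cast<size_t>(-1), so a StaticSchedulingElement can be created from a node alone and annotated later. A minimal, self-contained sketch of what the defaulted constructor allows; this is not Aidge code, the Node type is a stand-in, and the sentinel meaning is an assumption:

    #include <cstddef>
    #include <memory>

    struct Node {};  // stand-in for Aidge::Node

    struct StaticSchedulingElement {
        StaticSchedulingElement(
            std::shared_ptr<Node> node_,
            std::size_t early_ = static_cast<std::size_t>(-1),
            std::size_t late_ = static_cast<std::size_t>(-1))
            : node(node_), early(early_), late(late_) {}

        std::shared_ptr<Node> node;
        std::size_t early;
        std::size_t late;
    };

    int main() {
        auto n = std::make_shared<Node>();
        // Before this commit, both steps had to be given at construction time.
        // Now an element can be created first and receive its early/late steps
        // in a later pass; size_t(-1) presumably acts as a "not yet assigned"
        // sentinel.
        StaticSchedulingElement element(n);
        element.early = 0;
        element.late = 2;
        return 0;
    }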
@@ -69,8 +69,17 @@ public:
     };
 
     ~SequentialScheduler() = default;
 
-    void generateScheduling(bool verbose = false);
-    std::vector<StaticSchedulingElement> generateEarlyLateScheduling() const;
+    /**
+     * Generate the full static scheduling of the GraphView.
+     * For each node, an earliest and latest possible execution logical step
+     * is specified. Nodes that may be scheduled at the same logical step have
+     * no data dependency and can be run in parallel.
+     */
+    void generateScheduling();
+
+    /**
+     * Reset all scheduling and the associated nodes' producer-consumer state.
+     */
     void resetScheduling();
 
     /**
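With these declarations, the early/late computation is folded into generateScheduling() and the verbose flag is gone. A hedged usage sketch against the public API of this header; the graph variable, the output file name, and the include paths are assumptions rather than part of this commit:

    #include <memory>

    #include "aidge/graph/GraphView.hpp"
    #include "aidge/scheduler/Scheduler.hpp"

    // Sketch only: "graph" is assumed to be a GraphView whose dimensions have
    // already been forwarded, as in the unit test further down this page.
    void scheduleAndDump(std::shared_ptr<Aidge::GraphView> graph) {
        auto scheduler = Aidge::SequentialScheduler(graph);

        scheduler.generateScheduling();    // static schedule with early/late steps; no verbose flag anymore
        const auto nodes = scheduler.getStaticScheduling();          // sequential execution order
        scheduler.saveStaticSchedulingDiagram("static_scheduling");  // Markdown diagram with early/late order

        scheduler.resetScheduling();       // drop the schedule and the nodes' producer-consumer state
        (void)nodes;
    }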
@@ -92,26 +101,42 @@ public:
      */
     void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
 
+    /**
+     * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes.
+     * @param fileName Name of the generated file.
+     */
+    void saveStaticSchedulingDiagram(const std::string& fileName) const;
+
     /**
      * @brief Save in a Markdown file the order of layers execution.
      * @param fileName Name of the generated file.
      */
     void saveSchedulingDiagram(const std::string& fileName) const;
-    void saveStaticSchedulingDiagram(const std::string& fileName, const std::vector<StaticSchedulingElement>& scheduling) const;
 
     /**
      * @brief Return a vector of Node ordered by the order they are called by the scheduler
      * @return std::vector<std::shared_ptr<Node>>
      */
-    inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const noexcept {
-        return mStaticSchedule.at(step);
-    }
+    std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const;
 
     inline std::shared_ptr<GraphView> getGraphView() const noexcept {
         return mGraphView;
     }
 
 private:
+    /**
+     * Generate an initial base scheduling for the GraphView.
+     * The scheduling is entirely sequential and guaranteed to be valid w.r.t.
+     * each node's producer-consumer model.
+     */
+    std::vector<StaticSchedulingElement> generateBaseScheduling() const;
+
+    /**
+     * Fill in the early and late scheduling steps from the initial base scheduling.
+     * For each node, this specifies the earliest and latest possible execution
+     * logical step.
+     */
+    void generateEarlyLateScheduling(std::vector<StaticSchedulingElement>& schedule) const;
+
     /**
      * @brief Set of layers receiving an input from currently processing layers
      *
@@ -129,7 +154,7 @@ private:
     /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
     std::vector<SchedulingElement> mScheduling;
     /** @brief List of nodes ordered by their */
-    std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule;
+    std::vector<std::vector<StaticSchedulingElement>> mStaticSchedule;
     size_t mStaticScheduleStep = 0;
 };
 } // namespace Aidge
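The implementation diff is collapsed further down this page, so the following is only a guess at how the new pieces fit together: generateScheduling() presumably builds the sequential base schedule, annotates it in place with early/late steps, and stores the elements in the retyped mStaticSchedule; getStaticScheduling() would then strip the elements back down to node pointers. A speculative sketch, not the actual Scheduler.cpp code:

    // Speculative composition implied by the declarations above; the real
    // implementation lives in the collapsed .cpp diff and may differ.
    void Aidge::SequentialScheduler::generateScheduling() {
        // 1. Purely sequential order, valid w.r.t. each node's producer-consumer model.
        auto schedule = generateBaseScheduling();

        // 2. Annotate each element with its earliest and latest possible logical step.
        generateEarlyLateScheduling(schedule);

        // 3. Store the annotated elements for later queries and diagram generation.
        mStaticSchedule.push_back(schedule);
    }

    std::vector<std::shared_ptr<Aidge::Node>>
    Aidge::SequentialScheduler::getStaticScheduling(size_t step) const {
        // Extract the node pointers from the stored StaticSchedulingElement list.
        std::vector<std::shared_ptr<Node>> nodes;
        for (const auto& element : mStaticSchedule.at(step)) {
            nodes.push_back(element.node);
        }
        return nodes;
    }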
@@ -23,7 +23,7 @@ void init_Scheduler(py::module& m){
     .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>())
     .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
     .def("resetScheduling", &SequentialScheduler::resetScheduling)
-    .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
+    .def("generate_scheduling", &SequentialScheduler::generateScheduling)
     .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0)
     ;
 }
This diff is collapsed.
@@ -58,7 +58,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
         g1->forwardDims();
         auto scheduler = SequentialScheduler(g1);
 
-        scheduler.generateScheduling(true);
+        scheduler.generateScheduling();
         const auto sch = scheduler.getStaticScheduling();
         const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");
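For completeness, a hedged sketch of how the schedule and name table produced above might be walked to print the execution order; this loop is not part of the test shown here, and the map type of namePtrTable is an assumption:

    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>
    #include <vector>

    #include "aidge/graph/Node.hpp"

    // Illustrative helper, not from this commit: print nodes in scheduling order
    // using a name table such as the one returned by getRankedNodesName().
    void printSchedule(const std::vector<std::shared_ptr<Aidge::Node>>& sch,
                       const std::map<std::shared_ptr<Aidge::Node>, std::string>& namePtrTable) {
        for (const auto& node : sch) {
            std::cout << namePtrTable.at(node) << std::endl;
        }
    }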