diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index b95a298d6c354decebaa49f99744f19d0eab0572..71d92a35b7d1890a6a355342e18cd1f9477c1e1a 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -41,16 +41,25 @@ public: } /** - * Generate the memory layout for the current static scheduling. + * @brief Get the static scheduling sequential order of nodes following the + * current scheduling policy. + * @param step The step of the static schedule to retrieve (default is 0). + * @return Vector of shared pointers to Nodes in their scheduled order. + */ + std::vector<std::shared_ptr<Node>> getSequentialStaticScheduling(std::size_t step = 0) const; + + /** + * Generate the memory layout for the static scheduling following the + * current scheduling policy. * @param incProducers If true, include the producers in the memory layout. * @param wrapAroundBuffer If true, allow wrapping in memory planes. */ MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const; /** - * Generate the memory layout for the current static scheduling, with auto- - * concatenation: the Concat operator is replaced by direct allocation - * when possible. + * Generate the memory layout for the static scheduling following the + * current scheduling policy, with auto-concatenation: the Concat operator + * is replaced by direct allocation when possible. * @param incProducers If true, include the producers in the memory layout. * @param wrapAroundBuffer If true, allow wrapping in memory planes. */ diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 62bb7e0af1335b9e8d1004e6e8e99a425d3d6de1..dbe89436f4af246883a04d2a1e6a87022945a5a8 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -59,6 +59,7 @@ void init_Scheduler(py::module& m){ py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) + .def("get_sequential_static_scheduling", &SequentialScheduler::getSequentialStaticScheduling, py::arg("step") = 0) .def("generate_memory", &SequentialScheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) .def("generate_memory_auto_concat", &SequentialScheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 6a52882f65146a30c63d31dbffbe677506b9d33e..c89bf393ea12b9b10d6a1e0812c9ea0c5b753644 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -29,6 +29,10 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/recipes/GraphViewHelper.hpp" +std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getSequentialStaticScheduling(std::size_t step) const { + return Scheduler::getSequentialStaticScheduling(step, mSchedulingPolicy); +} + /** * @warning This version is a simplified version without special handling of concatenation. */ @@ -36,7 +40,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer MemoryManager memManager; for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { - for (const auto& node : getSequentialStaticScheduling(step, mSchedulingPolicy)) { + for (const auto& node : getSequentialStaticScheduling(step)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; @@ -159,7 +163,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemoryAutoConcat(bool i // have already been memory mapped! // But in any case, the policy must remain the same than the scheduler // export policy. - for (const auto& node : getSequentialStaticScheduling(step, mSchedulingPolicy)) { + for (const auto& node : getSequentialStaticScheduling(step)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; @@ -399,7 +403,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std } // Sort static scheduling according to the policy - const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy); + const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep); const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); for (const auto& runnable : nodes) { @@ -431,7 +435,7 @@ void Aidge::SequentialScheduler::backward() { } // map of node <-> info to display with verbose - const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy); + const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep); const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); // run scheduled operators in reverse order