From ae277f87ddea58dbd388eca9d85d3e027c8c217a Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 4 Mar 2025 22:08:58 +0100
Subject: [PATCH] Added getSequentialStaticScheduling() to SequentialScheduler

---
 include/aidge/scheduler/SequentialScheduler.hpp | 17 +++++++++++++----
 python_binding/scheduler/pybind_Scheduler.cpp   |  1 +
 src/scheduler/SequentialScheduler.cpp           | 12 ++++++++----
 3 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp
index b95a298d6..71d92a35b 100644
--- a/include/aidge/scheduler/SequentialScheduler.hpp
+++ b/include/aidge/scheduler/SequentialScheduler.hpp
@@ -41,16 +41,25 @@ public:
     }
 
     /**
-     * Generate the memory layout for the current static scheduling.
+     * @brief Get the static scheduling sequential order of nodes following the
+     * current scheduling policy.
+     * @param step The step of the static schedule to retrieve (default is 0).
+     * @return Vector of shared pointers to Nodes in their scheduled order.
+     */
+    std::vector<std::shared_ptr<Node>> getSequentialStaticScheduling(std::size_t step = 0) const;
+
+    /**
+     * Generate the memory layout for the static scheduling following the
+     * current scheduling policy.
      * @param incProducers If true, include the producers in the memory layout.
      * @param wrapAroundBuffer If true, allow wrapping in memory planes.
     */
     MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const;
 
     /**
-     * Generate the memory layout for the current static scheduling, with auto-
-     * concatenation: the Concat operator is replaced by direct allocation
-     * when possible.
+     * Generate the memory layout for the static scheduling following the
+     * current scheduling policy, with auto-concatenation: the Concat operator 
+     * is replaced by direct allocation when possible.
      * @param incProducers If true, include the producers in the memory layout.
      * @param wrapAroundBuffer If true, allow wrapping in memory planes.
     */
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index 62bb7e0af..dbe89436f 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -59,6 +59,7 @@ void init_Scheduler(py::module& m){
 
     py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
+    .def("get_sequential_static_scheduling", &SequentialScheduler::getSequentialStaticScheduling, py::arg("step") = 0)
     .def("generate_memory", &SequentialScheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
     .def("generate_memory_auto_concat", &SequentialScheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
     .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp
index 6a52882f6..c89bf393e 100644
--- a/src/scheduler/SequentialScheduler.cpp
+++ b/src/scheduler/SequentialScheduler.cpp
@@ -29,6 +29,10 @@
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/recipes/GraphViewHelper.hpp"
 
+std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getSequentialStaticScheduling(std::size_t step) const {
+    return Scheduler::getSequentialStaticScheduling(step, mSchedulingPolicy);
+}
+
 /**
  * @warning This version is a simplified version without special handling of concatenation.
  */
@@ -36,7 +40,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
     MemoryManager memManager;
 
     for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) {
-        for (const auto& node : getSequentialStaticScheduling(step, mSchedulingPolicy)) {
+        for (const auto& node : getSequentialStaticScheduling(step)) {
             if (!incProducers && node->type() == Producer_Op::Type) {
                 memManager.releaseDependencies(node);
                 continue;
@@ -159,7 +163,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemoryAutoConcat(bool i
         // have already been memory mapped!
         // But in any case, the policy must remain the same than the scheduler
         // export policy.
-        for (const auto& node : getSequentialStaticScheduling(step, mSchedulingPolicy)) {
+        for (const auto& node : getSequentialStaticScheduling(step)) {
             if (!incProducers && node->type() == Producer_Op::Type) {
                 memManager.releaseDependencies(node);
                 continue;
@@ -399,7 +403,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
     }
 
     // Sort static scheduling according to the policy
-    const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy);
+    const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep);
     const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
 
     for (const auto& runnable : nodes) {
@@ -431,7 +435,7 @@ void Aidge::SequentialScheduler::backward() {
     }
 
     // map of node <-> info to display with verbose
-    const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy);
+    const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep);
     const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
 
     // run scheduled operators in reverse order
-- 
GitLab