From 36994a1c80f18668750eea2cd921fb57e84e4f56 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 5 Sep 2023 13:07:11 +0000
Subject: [PATCH] [Scheduling] Refactor to separate scheduling & forward phase.

---
 include/aidge/backend/OperatorImpl.hpp        | 10 ++-
 include/aidge/operator/Operator.hpp           |  2 +
 include/aidge/scheduler/Scheduler.hpp         | 28 ++++--
 .../operator/pybind_GenericOperator.cpp       |  6 +-
 python_binding/scheduler/pybind_Scheduler.cpp |  3 +-
 src/operator/Operator.cpp                     |  3 +
 src/scheduler/Scheduler.cpp                   | 88 +++++++++++--------
 7 files changed, 91 insertions(+), 49 deletions(-)

diff --git a/include/aidge/backend/OperatorImpl.hpp b/include/aidge/backend/OperatorImpl.hpp
index 5aa2829e1..d10270b62 100644
--- a/include/aidge/backend/OperatorImpl.hpp
+++ b/include/aidge/backend/OperatorImpl.hpp
@@ -20,7 +20,7 @@ namespace Aidge {
 class OperatorImpl {
 public:
     virtual void forward(){};
-    virtual void backward() {}
+    virtual void backward(){};
 
     /**
      * @brief Minimum amount of data from a specific input required by the
@@ -46,13 +46,19 @@ public:
     virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0;
 
     /**
-     * @brief TOtal amount of produced data ready to be used on a specific output.
+     * @brief Total amount of produced data ready to be used on a specific output.
      *
      * @param outputIdx Index of the output analysed.
      * @return DimSize_t
      */
     virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0;
 
+    /**
+     * @brief Update the Consummer Producer system by simulating the consumption and production of i/o
+     *
+     */
+    virtual void updateConsummerProducer() = 0;
+
     virtual ~OperatorImpl() = default;
 };
 } // namespace Aidge
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 30e1ce2a7..36f846dda 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -78,6 +78,8 @@ public:
      */
     NbElts_t getNbProducedData(const IOIndex_t outputIdx) const;
 
+    void updateConsummerProducer();
+
     virtual void forward();
 
     virtual void backward();
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index fb39bcf4a..9916ee200 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -43,6 +43,8 @@ public:
     };
     ~SequentialScheduler() = default;
 
+    void generateScheduling(bool verbose = false);
+
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
@@ -59,12 +61,8 @@ public:
      *
      * @return std::vector<std::shared_ptr<Node>>
      */
-    std::vector<std::shared_ptr<Node>> getNodeScheduling(){
-        std::vector<std::shared_ptr<Node>> nodeScheduling = {};
-        for(SchedulingElement & scheduleElt: mScheduling){
-            nodeScheduling.push_back(scheduleElt.node);
-        }
-        return nodeScheduling;
+    std::vector<std::shared_ptr<Node>> getStaticScheduling(){
+        return mStaticSchedule;
     }
 
 private:
@@ -76,8 +74,26 @@ private:
      */
     std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
 
+    /**
+     * @brief Shared ptr to the scheduled graph view
+     *
+     */
     std::shared_ptr<GraphView> mGraphView;
+    /**
+     * @brief List of SchedulingElement (i.e: Nodes with their computation time)
+     *
+     */
     std::vector<SchedulingElement> mScheduling;
+    /**
+     * @brief List of nodes ordered by their
+     *
+     */
+    std::vector<std::shared_ptr<Node>> mStaticSchedule;
+    /**
+     * @brief Number of computation node (i.e: nb nodes != Producer)
+     *
+     */
+    std::size_t mComputationNumber = 0; // TODO: Check if not inferable from mStaticSchedule
 };
 } // namespace Aidge
 
diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp
index ee3ee74c1..bec59eaf2 100644
--- a/python_binding/operator/pybind_GenericOperator.cpp
+++ b/python_binding/operator/pybind_GenericOperator.cpp
@@ -22,7 +22,7 @@ namespace Aidge {
 void init_GenericOperator(py::module& m) {
     py::class_<GenericOperator_Op, std::shared_ptr<GenericOperator_Op>, Operator>(m, "GenericOperatorOp",
                                                                                   py::multiple_inheritance())
-            .def("get_parameter_type", &GenericOperator_Op::getParameterType)
+    .def("get_parameter_type", &GenericOperator_Op::getParameterType)
     .def("get_parameters_name", &GenericOperator_Op::getParametersName)
     .def("add_parameter", &GenericOperator_Op::addParameter<bool>)
     .def("add_parameter", &GenericOperator_Op::addParameter<int>)
@@ -34,10 +34,10 @@ void init_GenericOperator(py::module& m) {
     .def("add_parameter", &GenericOperator_Op::addParameter<std::vector<std::string>>)
     .def("get_parameter", [](GenericOperator_Op& self, std::string key) -> py::object {
         /*
-        This getParameter method returns the good python type without having to have 
+        This getParameter method returns the good python type without having to have
         prior knowledge of the parameter type.
         */
-        py::object res = py::none(); 
+        py::object res = py::none();
         std::string paramType = self.getParameterType(key);
         if(paramType == typeid(int).name())
             res = py::cast(self.getParameter<int>(key));
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index ab7d3d850..85479d41f 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -21,7 +21,8 @@ void init_Scheduler(py::module& m){
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false)
     .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
-    .def("get_node_scheduling", &SequentialScheduler::getNodeScheduling)
+    .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
+    .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling)
     ;
 }
 }
diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp
index 99b07235e..b3896b121 100644
--- a/src/operator/Operator.cpp
+++ b/src/operator/Operator.cpp
@@ -38,6 +38,9 @@ Aidge::NbElts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) co
 Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
     return mImpl->getNbProducedData(outputIdx);
 }
+void Aidge::Operator::updateConsummerProducer(){
+    mImpl->updateConsummerProducer();
+}
 
 void Aidge::Operator::forward() { mImpl->forward(); }
 
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index c8cc0a51d..dc0768d2b 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -33,26 +33,19 @@ void drawProgressBar(double progress, int barWidth, const std::string& additiona
     fflush(stdout);
 }
 
-// TODO: handle multiple inputs/outputs
-void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
-    if (forwardDims) {mGraphView->forwardDims(); }
-
-    mScheduling.clear();
-
+void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
     // setup initial producers list
-    // add each Producer Node.
-    std::set<std::shared_ptr<Node>> computationOver;
-    std::size_t computationNumber = 0;
+    mComputationNumber = 0;
     std::set<std::shared_ptr<Node>> producers;
     for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
         if (nodePtr->type() == "Producer") {
             producers.insert(nodePtr);
         } else {
-            ++computationNumber;
+            ++mComputationNumber;
         }
     }
     // add Data Input
-    // FIXME : shoudl be changed when the real system for providing
+    // FIXME : should be changed when the real system for providing
     // data is implemented
     for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) {
         for (const auto& parentPtr : nodePtr->getParents()) {
@@ -112,21 +105,10 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
             }
         }
 
-        // run sequencially every runnable consumers once
-        // TODO: handle memory allocation in scheduler
-        // TODO: optimize memory usage
+        // Push consumers in the list of nodes to run and update the consumer producer system
         for (const auto& runnable : runnableConsumers) {
-            if (verbose)
-                printf("run: %s\n",
-                       (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str());
-            else
-                drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(computationNumber), 50,
-                                (std::string("running ") + runnable->type() + "_" +
-                                 std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))));
-            const auto tStart = std::chrono::high_resolution_clock::now();
-            runnable->forward();
-            const auto tEnd = std::chrono::high_resolution_clock::now();
-            mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
+            runnable->getOperator()->updateConsummerProducer();
+            mStaticSchedule.push_back(runnable);
         }
 
         // update producers and consumers list
@@ -164,18 +146,6 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
                 }
             }
 
-            bool computationOverForConsumer = true;
-            for (IOIndex_t parentIDi = 0; parentIDi < consumer->nbInputs(); ++parentIDi) {
-                if (consumer->getOperator()->getNbConsumedData(parentIDi) <
-                    consumer->getOperator()->getNbRequiredData(parentIDi)) {
-                    computationOverForConsumer = false;
-                    break;
-                }
-            }
-            if (computationOverForConsumer) {
-                computationOver.insert(consumer);
-            }
-
             for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
                 if (consumer->getOperator()->getNbProducedData(outId) > 0) {
                     if (verbose) printf("  also producer\n");
@@ -197,8 +167,52 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
 
         if (verbose) printf("*************\n");
     } while (!consumers.empty());
+
+}
+
+// TODO: handle multiple inputs/outputs
+void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
+    if (forwardDims) {mGraphView->forwardDims(); }
+
+    // add each Producer Node.
+    std::set<std::shared_ptr<Node>> computationOver;
+
+    mScheduling.clear();
+
+    this->generateScheduling();
+
+    // TODO: For loop on the list of node to run
+    // run sequencially every runnable consumers once
+    // TODO: handle memory allocation in scheduler
+    // TODO: optimize memory usage
+    for (const auto& runnable : mStaticSchedule) {
+        bool computationOverForConsumer = true;
+        for (IOIndex_t parentIDi = 0; parentIDi < runnable->nbInputs(); ++parentIDi) {
+            if (runnable->getOperator()->getNbConsumedData(parentIDi) <
+                runnable->getOperator()->getNbRequiredData(parentIDi)) {
+                computationOverForConsumer = false;
+                break;
+            }
+        }
+        if (computationOverForConsumer) {
+            computationOver.insert(runnable);
+        }
+
+        if (verbose)
+            printf("run: %s\n",
+                    (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str());
+        else
+            drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(mComputationNumber), 50,
+                            (std::string("running ") + runnable->type() + "_" +
+                                std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))));
+        const auto tStart = std::chrono::high_resolution_clock::now();
+        runnable->forward();
+        const auto tEnd = std::chrono::high_resolution_clock::now();
+        mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
+    }
     if (!verbose) drawProgressBar(1.0, 50, "                                   ");
     printf("\n");
+
 }
 
 void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
-- 
GitLab