From 704d69a7dd50ce5eddf8521a41bd27bf4ff210c9 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 18 Feb 2024 19:02:45 +0100
Subject: [PATCH] Major rework of Scheduler to actually work with MetaOperator

---
 include/aidge/operator/MetaOperator.hpp |  13 ++-
 include/aidge/scheduler/Scheduler.hpp   |  18 +++-
 src/operator/MetaOperator.cpp           |   5 +-
 src/scheduler/Scheduler.cpp             | 138 ++++++++++++++++++------
 4 files changed, 130 insertions(+), 44 deletions(-)

diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index b38a2befe..102f33a37 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -25,6 +25,7 @@ public:
     // Micro-graph handling:
     std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph
     std::shared_ptr<SequentialScheduler> mScheduler;
+    std::weak_ptr<Node> mUpperNode;
 
    public:
     MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph);
@@ -38,6 +39,13 @@ public:
           mGraph(op.mGraph->clone())
     {}
 
+    /**
+     * Set the node that should be used for the scheduling.
+    */
+    void setUpperNode(std::shared_ptr<Node> node) {
+        mUpperNode = node;
+    }
+
     /**
      * @brief Clone the operator using its copy-constructor.
      * @see Operator::MetaOperator_Op
@@ -108,7 +116,10 @@ inline std::shared_ptr<Node> MetaOperator(const char *type,
                                   const std::shared_ptr<GraphView>& graph,
                                   const std::string& name = "")
 {
-    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name);
+    auto op = std::make_shared<MetaOperator_Op>(type, graph);
+    auto node = std::make_shared<Node>(op, name);
+    op->setUpperNode(node);
+    return node;
 }
 }  // namespace Aidge
 
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 9904683ba..a90b7ea18 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -19,6 +19,8 @@
 #include <vector>
 #include <map>
 
+#include "aidge/utils/Types.h"
+
 namespace Aidge {
 class Node;
 class GraphView;
@@ -44,8 +46,9 @@ private:
     };
 
 public:
-    SequentialScheduler(std::shared_ptr<GraphView> graphView)
-        : mGraphView(graphView)
+    SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
+        : mGraphView(graphView),
+          mUpperNode(upperNode)
     {
         // ctor
     };
@@ -55,6 +58,7 @@ public:
     inline void resetScheduling() {
         mScheduling.clear();
         mStaticSchedule.clear();
+        mStaticScheduleStep = 0;
     }
 
     /**
@@ -72,8 +76,8 @@ public:
      * @brief Return a vector of Node ordered by the order they are called by the scheduler
      * @return std::vector<std::shared_ptr<Node>>
      */
-    inline std::vector<std::shared_ptr<Node>> getStaticScheduling() const noexcept {
-        return mStaticSchedule;
+    inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const noexcept {
+        return mStaticSchedule.at(step);
     }
     inline std::shared_ptr<GraphView> getGraphView() const noexcept {
         return mGraphView;
@@ -87,14 +91,18 @@ private:
      * @return std::set<std::shared_ptr<Node>>
      */
     std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
+    NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
     PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
 
     /** @brief Shared ptr to the scheduled graph view */
     std::shared_ptr<GraphView> mGraphView;
+    /** @brief Shared ptr to the upper node containing the graph view */
+    std::weak_ptr<Node> mUpperNode;
     /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
     std::vector<SchedulingElement> mScheduling;
     /** @brief List of nodes ordered by their */
-    std::vector<std::shared_ptr<Node>> mStaticSchedule;
+    std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule;
+    size_t mStaticScheduleStep = 0;
 };
 } // namespace Aidge
 
diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp
index 530357085..0ff758a56 100644
--- a/src/operator/MetaOperator.cpp
+++ b/src/operator/MetaOperator.cpp
@@ -65,10 +65,9 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() {
     else {
         if (!mScheduler) {
             // Lazy initialization
-            mScheduler = std::make_shared<SequentialScheduler>(mGraph);
+            mScheduler = std::make_shared<SequentialScheduler>(mGraph, mUpperNode.lock());
         }
 
-
         // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule.
         // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()"
         mScheduler->generateScheduling();
@@ -86,7 +85,7 @@ void Aidge::MetaOperator_Op::forward() {
             // Lazy initialization
             // TODO: should we assert that a scheduler already exists at this point?
             // => should be created in updateConsummerProducer()
-            mScheduler = std::make_shared<SequentialScheduler>(mGraph);
+            mScheduler = std::make_shared<SequentialScheduler>(mGraph, mUpperNode.lock());
             mScheduler->generateScheduling();
         }
 
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 646025d23..d74d1980c 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -24,6 +24,7 @@
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/Memorize.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 
 void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
     putchar('[');
@@ -68,6 +69,13 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
     std::map<std::shared_ptr<Node>, std::string> namePtrTable;
     if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
 
+    // Still consumers are consumers that were run by can still consume data.
+    // They must be run AFTER the remaining consumer to ensure a non-greedy
+    // producers-consumers model!
+    std::set<std::shared_ptr<Node>> stillConsumers;
+
+    mStaticSchedule.push_back(std::vector<std::shared_ptr<Node>>());
+
     do {
         // From the current consumers list, check if any prior nodes are needed.
         // If for a given node, only parent producers (at any depth) are needed
@@ -121,7 +129,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
         // Make producers generate the required data
         for (const auto& requiredProducer : requiredProducers) {
             requiredProducer->getOperator()->updateConsummerProducer();
-            mStaticSchedule.push_back(requiredProducer);
+            mStaticSchedule.back().push_back(requiredProducer);
         }
 
         // find runnable consumers
@@ -150,23 +158,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
             }
             
             bool isRunnable = true;
-
-            IOIndex_t inputIdx = 0;  // FIXME: handle this correctly
-            // Check every input has got enought data to run
-            for (const auto& consumerParent : consumer->inputs()) {
-                if (consumerParent.first &&
-                    (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
-                            consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
-                    if (verbose) printf("  not runnable: C%zu + R%zu > P%zu\n",
+            for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
+                if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
+                    && */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
+                            getNbAvailableData(consumer, inputIdx)) {
+                    if (verbose) printf("  not runnable: C%zu + R%zu > P%zu for input #%d\n",
                         consumer->getOperator()->getNbConsumedData(inputIdx),
                         consumer->getOperator()->getNbRequiredData(inputIdx),
-                        consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
+                        getNbAvailableData(consumer, inputIdx), inputIdx);
 
                     // not enough data to run
                     isRunnable = false;
                     break;
                 }
-                ++inputIdx;
             }
 
             if (isRunnable) {
@@ -178,11 +182,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
         for (const auto& runnable : runnableConsumers) {
             if (verbose) printf("Runnable: %s\n", namePtrTable[runnable].c_str());
             runnable->getOperator()->updateConsummerProducer();
-            mStaticSchedule.push_back(runnable);
+            mStaticSchedule.back().push_back(runnable);
         }
 
         if (runnableConsumers.empty()) {
-            frozenConsumers.push_back(consumers);
+            if (std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end()) {
+                frozenConsumers.push_back(consumers);
+            }
+            else {
+                break;
+            }
         }
         else {
             frozenConsumers.clear();
@@ -209,23 +218,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
                 printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
                 printf("\n");
             }
-            bool isStillConsumer = false;
 
-            IOIndex_t inputIdx = 0;  // FIXME: handle this correctly
-            // should we check input or dataInput ?
-            for (const auto& consumerParent : consumer->inputs()) {
-                if (consumerParent.first &&
-                    consumer->getOperator()->getNbConsumedData(inputIdx) <
-                            consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
-                    if (verbose) printf("  still consumer: C%zu < P%zu\n",
+            bool isStillConsumer = false;
+            for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
+                if (consumer->getOperator()->getNbConsumedData(inputIdx) <
+                            getNbAvailableData(consumer, inputIdx)) {
+                    if (verbose) printf("  still consumer: C%zu < P%zu for input #%d\n",
                         consumer->getOperator()->getNbConsumedData(inputIdx),
-                        consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
+                        getNbAvailableData(consumer, inputIdx), inputIdx);
 
                     // there is still data to consume
                     isStillConsumer = true;
                     break;
                 }
-                ++inputIdx;
             }
 
             for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
@@ -234,21 +239,34 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
                     // make sure consumer is also a producer
                     producers.insert(consumer);
 
-                    const auto& childs = consumer->getChildren();
-                    consumers.insert(childs.begin(), childs.end());
+                    const auto& newConsumers = getConsumers({consumer});
+                    consumers.insert(newConsumers.cbegin(), newConsumers.cend());
                     break;
                 }
             }
 
-            if (!isStillConsumer) {
-                if (verbose) printf("  no more consumer\n");
-                // consumer is no longer a consumer, only a producer
+            if (runnableConsumers.find(consumer) != runnableConsumers.end()) {
+                // If consumer was run, remove it from the consumers list for
+                // now
                 consumers.erase(consumer);
+                if (isStillConsumer) {
+                    // If there is still data to consume, the consumer will be
+                    // run AFTER the other remaining consumers
+                    // (= non-greedy consumers)
+                    stillConsumers.insert(consumer);
+                }
             }
         }
 
+        // If there is no more consumers, swap with possible "still consumers"
+        // This ensures that the "non-greedy" consumer behavior
+        if (consumers.empty()) {
+            consumers.swap(stillConsumers);
+            stillConsumers.clear();
+        }
+
         if (verbose) printf("********************\n");
-    } while (!consumers.empty() && (frozenConsumers.empty() || std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end()));
+    } while (!consumers.empty());
 
     if (verbose) {
         if (!consumers.empty()) {
@@ -270,14 +288,11 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
         this->generateScheduling(verbose);
     }
 
-    // Clear previous scheduling results
-    mScheduling.clear();
-
     std::map<std::shared_ptr<Node>, std::string> namePtrTable;
     if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
 
-    int cpt = 0;
-    for (const auto& runnable : mStaticSchedule) {
+    size_t cpt = 0;
+    for (const auto& runnable : mStaticSchedule.at(mStaticScheduleStep)) {
         if (verbose)
             printf("run: %s\n",
                     namePtrTable[runnable].c_str());
@@ -292,6 +307,8 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
     }
     if (!verbose) drawProgressBar(1.0, 50, "                                   ");
     printf("\n");
+
+    ++mStaticScheduleStep;
 }
 
 void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
@@ -321,12 +338,58 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
 
     for (const auto& producer : producers) {
         const auto& childs = producer->getChildren();
-        consumers.insert(childs.begin(), childs.end());
+        for (const auto& child : childs) {
+            // Do not schedule childs outside current graph!
+            if (mGraphView->inView(child)) {
+                consumers.insert(child);
+            }
+        }
     }
 
     return consumers;
 }
 
+Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
+    const auto parent = node->inputs()[inputIdx];
+
+    if (parent.first) {
+        // Parent is connected, everything if fine!
+        return parent.first->getOperator()->getNbProducedData(parent.second);
+    }
+    else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) {
+        // We are inside an upper operator (for instance a MetaOperator)
+        // We need to connect the "local" producer-consumer model to the upper
+        // one, by mapping local node inputs to the upper node inputs.
+        IOIndex_t nodeInputIdx = 0;
+        for (const auto& input : mGraphView->getOrderedInputs()) {
+            if (input.first == node) {
+                // Current node is an input
+                const auto upperInput = upperNode->inputs()[nodeInputIdx];
+                if (upperInput.first) {
+                    return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
+                } 
+            }
+            ++nodeInputIdx;
+        }
+    }
+
+    // Otherwise, two cases:
+    if (node->getOperator()->getRawInput(inputIdx)) {
+        // Input is not connected but a valid tensor exists
+        // => This means data was fed manually to the input, without a Producer
+        // In this case, we assume a single-use data (unlike a Producer, which
+        // keep producing the data each time it is needed).
+        fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
+        return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
+    }
+    else {
+        // Input is not connected, this is an error
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
+    }
+
+    return 0;
+}
+
 Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
     const std::shared_ptr<Node>& node) const
 {
@@ -338,6 +401,11 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::
             (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
                     parent.first->getOperator()->getNbProducedData(parent.second))
         {
+            if (!mGraphView->inView(parent.first)) {
+                // Do not schedule prior outside the current graph!
+                return PriorProducersConsumers();
+            }
+
             if (parent.first->type() == Producer_Op::Type) {
                 prior.requiredProducers.insert(parent.first);
                 prior.priorConsumers.insert(node);
-- 
GitLab