From 704d69a7dd50ce5eddf8521a41bd27bf4ff210c9 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 18 Feb 2024 19:02:45 +0100 Subject: [PATCH] Major rework of Scheduler to actually work with MetaOperator --- include/aidge/operator/MetaOperator.hpp | 13 ++- include/aidge/scheduler/Scheduler.hpp | 18 +++- src/operator/MetaOperator.cpp | 5 +- src/scheduler/Scheduler.cpp | 138 ++++++++++++++++++------ 4 files changed, 130 insertions(+), 44 deletions(-) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index b38a2befe..102f33a37 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -25,6 +25,7 @@ public: // Micro-graph handling: std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph std::shared_ptr<SequentialScheduler> mScheduler; + std::weak_ptr<Node> mUpperNode; public: MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph); @@ -38,6 +39,13 @@ public: mGraph(op.mGraph->clone()) {} + /** + * Set the node that should be used for the scheduling. + */ + void setUpperNode(std::shared_ptr<Node> node) { + mUpperNode = node; + } + /** * @brief Clone the operator using its copy-constructor. 
* @see Operator::MetaOperator_Op @@ -108,7 +116,10 @@ inline std::shared_ptr<Node> MetaOperator(const char *type, const std::shared_ptr<GraphView>& graph, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name); + auto op = std::make_shared<MetaOperator_Op>(type, graph); + auto node = std::make_shared<Node>(op, name); + op->setUpperNode(node); + return node; } } // namespace Aidge diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 9904683ba..a90b7ea18 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -19,6 +19,8 @@ #include <vector> #include <map> +#include "aidge/utils/Types.h" + namespace Aidge { class Node; class GraphView; @@ -44,8 +46,9 @@ private: }; public: - SequentialScheduler(std::shared_ptr<GraphView> graphView) - : mGraphView(graphView) + SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) + : mGraphView(graphView), + mUpperNode(upperNode) { // ctor }; @@ -55,6 +58,7 @@ public: inline void resetScheduling() { mScheduling.clear(); mStaticSchedule.clear(); + mStaticScheduleStep = 0; } /** @@ -72,8 +76,8 @@ public: * @brief Return a vector of Node ordered by the order they are called by the scheduler * @return std::vector<std::shared_ptr<Node>> */ - inline std::vector<std::shared_ptr<Node>> getStaticScheduling() const noexcept { - return mStaticSchedule; + inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const noexcept { + return mStaticSchedule.at(step); } inline std::shared_ptr<GraphView> getGraphView() const noexcept { return mGraphView; @@ -87,14 +91,18 @@ private: * @return std::set<std::shared_ptr<Node>> */ std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const; + NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const; 
PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const; /** @brief Shared ptr to the scheduled graph view */ std::shared_ptr<GraphView> mGraphView; + /** @brief Shared ptr to the upper node containing the graph view */ + std::weak_ptr<Node> mUpperNode; /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */ std::vector<SchedulingElement> mScheduling; /** @brief List of nodes ordered by their */ - std::vector<std::shared_ptr<Node>> mStaticSchedule; + std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule; + size_t mStaticScheduleStep = 0; }; } // namespace Aidge diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 530357085..0ff758a56 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -65,10 +65,9 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() { else { if (!mScheduler) { // Lazy initialization - mScheduler = std::make_shared<SequentialScheduler>(mGraph); + mScheduler = std::make_shared<SequentialScheduler>(mGraph, mUpperNode.lock()); } - // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule. // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()" mScheduler->generateScheduling(); @@ -86,7 +85,7 @@ void Aidge::MetaOperator_Op::forward() { // Lazy initialization // TODO: should we assert that a scheduler already exists at this point? 
// => should be created in updateConsummerProducer() - mScheduler = std::make_shared<SequentialScheduler>(mGraph); + mScheduler = std::make_shared<SequentialScheduler>(mGraph, mUpperNode.lock()); mScheduler->generateScheduling(); } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 646025d23..d74d1980c 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -24,6 +24,7 @@ #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Memorize.hpp" +#include "aidge/operator/MetaOperator.hpp" void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { putchar('['); @@ -68,6 +69,13 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { std::map<std::shared_ptr<Node>, std::string> namePtrTable; if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + // Still consumers are consumers that were run but can still consume data. + // They must be run AFTER the remaining consumers to ensure a non-greedy + // producers-consumers model! + std::set<std::shared_ptr<Node>> stillConsumers; + + mStaticSchedule.push_back(std::vector<std::shared_ptr<Node>>()); + do { // From the current consumers list, check if any prior nodes are needed. 
// If for a given node, only parent producers (at any depth) are needed @@ -121,7 +129,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // Make producers generate the required data for (const auto& requiredProducer : requiredProducers) { requiredProducer->getOperator()->updateConsummerProducer(); - mStaticSchedule.push_back(requiredProducer); + mStaticSchedule.back().push_back(requiredProducer); } // find runnable consumers @@ -150,23 +158,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } bool isRunnable = true; - - IOIndex_t inputIdx = 0; // FIXME: handle this correctly - // Check every input has got enought data to run - for (const auto& consumerParent : consumer->inputs()) { - if (consumerParent.first && - (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > - consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { - if (verbose) printf(" not runnable: C%zu + R%zu > P%zu\n", + for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { + if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0 + && */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > + getNbAvailableData(consumer, inputIdx)) { + if (verbose) printf(" not runnable: C%zu + R%zu > P%zu for input #%d\n", consumer->getOperator()->getNbConsumedData(inputIdx), consumer->getOperator()->getNbRequiredData(inputIdx), - consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)); + getNbAvailableData(consumer, inputIdx), inputIdx); // not enough data to run isRunnable = false; break; } - ++inputIdx; } if (isRunnable) { @@ -178,11 +182,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { for (const auto& runnable : runnableConsumers) { if (verbose) printf("Runnable: %s\n", namePtrTable[runnable].c_str()); runnable->getOperator()->updateConsummerProducer(); 
- mStaticSchedule.push_back(runnable); + mStaticSchedule.back().push_back(runnable); } if (runnableConsumers.empty()) { - frozenConsumers.push_back(consumers); + if (std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end()) { + frozenConsumers.push_back(consumers); + } + else { + break; + } } else { frozenConsumers.clear(); @@ -209,23 +218,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); printf("\n"); } - bool isStillConsumer = false; - IOIndex_t inputIdx = 0; // FIXME: handle this correctly - // should we check input or dataInput ? - for (const auto& consumerParent : consumer->inputs()) { - if (consumerParent.first && - consumer->getOperator()->getNbConsumedData(inputIdx) < - consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) { - if (verbose) printf(" still consumer: C%zu < P%zu\n", + bool isStillConsumer = false; + for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { + if (consumer->getOperator()->getNbConsumedData(inputIdx) < + getNbAvailableData(consumer, inputIdx)) { + if (verbose) printf(" still consumer: C%zu < P%zu for input #%d\n", consumer->getOperator()->getNbConsumedData(inputIdx), - consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)); + getNbAvailableData(consumer, inputIdx), inputIdx); // there is still data to consume isStillConsumer = true; break; } - ++inputIdx; } for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) { @@ -234,21 +239,34 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // make sure consumer is also a producer producers.insert(consumer); - const auto& childs = consumer->getChildren(); - consumers.insert(childs.begin(), childs.end()); + const auto& newConsumers = getConsumers({consumer}); + consumers.insert(newConsumers.cbegin(), newConsumers.cend()); 
break; } } - if (!isStillConsumer) { - if (verbose) printf(" no more consumer\n"); - // consumer is no longer a consumer, only a producer + if (runnableConsumers.find(consumer) != runnableConsumers.end()) { + // If consumer was run, remove it from the consumers list for + // now consumers.erase(consumer); + if (isStillConsumer) { + // If there is still data to consume, the consumer will be + // run AFTER the other remaining consumers + // (= non-greedy consumers) + stillConsumers.insert(consumer); + } } } + // If there are no more consumers, swap with possible "still consumers" + // This ensures the "non-greedy" consumer behavior + if (consumers.empty()) { + consumers.swap(stillConsumers); + stillConsumers.clear(); + } + if (verbose) printf("********************\n"); - } while (!consumers.empty() && (frozenConsumers.empty() || std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end())); + } while (!consumers.empty()); if (verbose) { if (!consumers.empty()) { @@ -270,14 +288,11 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { this->generateScheduling(verbose); } - // Clear previous scheduling results - mScheduling.clear(); - std::map<std::shared_ptr<Node>, std::string> namePtrTable; if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - int cpt = 0; - for (const auto& runnable : mStaticSchedule) { + size_t cpt = 0; + for (const auto& runnable : mStaticSchedule.at(mStaticScheduleStep)) { if (verbose) printf("run: %s\n", namePtrTable[runnable].c_str()); @@ -292,6 +307,8 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { } if (!verbose) drawProgressBar(1.0, 50, " "); printf("\n"); + + ++mStaticScheduleStep; } void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { @@ -321,12 +338,58 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( for (const auto& producer : producers) { const 
auto& childs = producer->getChildren(); - consumers.insert(childs.begin(), childs.end()); + for (const auto& child : childs) { + // Do not schedule children outside the current graph! + if (mGraphView->inView(child)) { + consumers.insert(child); + } + } } return consumers; } +Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const { + const auto parent = node->inputs()[inputIdx]; + + if (parent.first) { + // Parent is connected, everything is fine! + return parent.first->getOperator()->getNbProducedData(parent.second); + } + else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) { + // We are inside an upper operator (for instance a MetaOperator) + // We need to connect the "local" producer-consumer model to the upper + // one, by mapping local node inputs to the upper node inputs. + IOIndex_t nodeInputIdx = 0; + for (const auto& input : mGraphView->getOrderedInputs()) { + if (input.first == node) { + // Current node is an input + const auto upperInput = upperNode->inputs()[nodeInputIdx]; + if (upperInput.first) { + return upperInput.first->getOperator()->getNbProducedData(upperInput.second); + } + } + ++nodeInputIdx; + } + } + + // Otherwise, two cases: + if (node->getOperator()->getRawInput(inputIdx)) { + // Input is not connected but a valid tensor exists + // => This means data was fed manually to the input, without a Producer + // In this case, we assume a single-use data (unlike a Producer, which + // keeps producing the data each time it is needed). 
+ fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type()); + return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size(); + } + else { + // Input is not connected, this is an error + AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type()); + } + + return 0; +} + Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers( const std::shared_ptr<Node>& node) const { @@ -338,6 +401,11 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler:: (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) > parent.first->getOperator()->getNbProducedData(parent.second)) { + if (!mGraphView->inView(parent.first)) { + // Do not schedule prior outside the current graph! + return PriorProducersConsumers(); + } + if (parent.first->type() == Producer_Op::Type) { prior.requiredProducers.insert(parent.first); prior.priorConsumers.insert(node); -- GitLab