Skip to content
Snippets Groups Projects
Commit 704d69a7 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Major rework of Scheduler to actually work with MetaOperator

parent 2787ffb0
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ public:
// Micro-graph handling:
std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph
std::shared_ptr<SequentialScheduler> mScheduler;
std::weak_ptr<Node> mUpperNode;
public:
MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph);
......@@ -38,6 +39,13 @@ public:
mGraph(op.mGraph->clone())
{}
/**
* Set the node that should be used for the scheduling.
*/
void setUpperNode(std::shared_ptr<Node> node) {
mUpperNode = node;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::MetaOperator_Op
......@@ -108,7 +116,10 @@ inline std::shared_ptr<Node> MetaOperator(const char *type,
const std::shared_ptr<GraphView>& graph,
const std::string& name = "")
{
return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name);
auto op = std::make_shared<MetaOperator_Op>(type, graph);
auto node = std::make_shared<Node>(op, name);
op->setUpperNode(node);
return node;
}
} // namespace Aidge
......
......@@ -19,6 +19,8 @@
#include <vector>
#include <map>
#include "aidge/utils/Types.h"
namespace Aidge {
class Node;
class GraphView;
......@@ -44,8 +46,9 @@ private:
};
public:
SequentialScheduler(std::shared_ptr<GraphView> graphView)
: mGraphView(graphView)
SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
: mGraphView(graphView),
mUpperNode(upperNode)
{
// ctor
};
......@@ -55,6 +58,7 @@ public:
inline void resetScheduling() {
mScheduling.clear();
mStaticSchedule.clear();
mStaticScheduleStep = 0;
}
/**
......@@ -72,8 +76,8 @@ public:
* @brief Return a vector of Node ordered by the order they are called by the scheduler
* @return std::vector<std::shared_ptr<Node>>
*/
inline std::vector<std::shared_ptr<Node>> getStaticScheduling() const noexcept {
return mStaticSchedule;
inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const noexcept {
return mStaticSchedule.at(step);
}
inline std::shared_ptr<GraphView> getGraphView() const noexcept {
return mGraphView;
......@@ -87,14 +91,18 @@ private:
* @return std::set<std::shared_ptr<Node>>
*/
std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;
NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;
PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;
/** @brief Shared ptr to the scheduled graph view */
std::shared_ptr<GraphView> mGraphView;
/** @brief Shared ptr to the upper node containing the graph view */
std::weak_ptr<Node> mUpperNode;
/** @brief List of SchedulingElement (i.e: Nodes with their computation time) */
std::vector<SchedulingElement> mScheduling;
/** @brief List of nodes ordered by their */
std::vector<std::shared_ptr<Node>> mStaticSchedule;
std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule;
size_t mStaticScheduleStep = 0;
};
} // namespace Aidge
......
......@@ -65,10 +65,9 @@ void Aidge::MetaOperator_Op::updateConsummerProducer() {
else {
if (!mScheduler) {
// Lazy initialization
mScheduler = std::make_shared<SequentialScheduler>(mGraph);
mScheduler = std::make_shared<SequentialScheduler>(mGraph, mUpperNode.lock());
}
// TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule.
// It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()"
mScheduler->generateScheduling();
......@@ -86,7 +85,7 @@ void Aidge::MetaOperator_Op::forward() {
// Lazy initialization
// TODO: should we assert that a scheduler already exists at this point?
// => should be created in updateConsummerProducer()
mScheduler = std::make_shared<SequentialScheduler>(mGraph);
mScheduler = std::make_shared<SequentialScheduler>(mGraph, mUpperNode.lock());
mScheduler->generateScheduling();
}
......
......@@ -24,6 +24,7 @@
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
putchar('[');
......@@ -68,6 +69,13 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Still consumers are consumers that were run by can still consume data.
// They must be run AFTER the remaining consumer to ensure a non-greedy
// producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers;
mStaticSchedule.push_back(std::vector<std::shared_ptr<Node>>());
do {
// From the current consumers list, check if any prior nodes are needed.
// If for a given node, only parent producers (at any depth) are needed
......@@ -121,7 +129,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// Make producers generate the required data
for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer();
mStaticSchedule.push_back(requiredProducer);
mStaticSchedule.back().push_back(requiredProducer);
}
// find runnable consumers
......@@ -150,23 +158,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
}
bool isRunnable = true;
IOIndex_t inputIdx = 0; // FIXME: handle this correctly
// Check every input has got enought data to run
for (const auto& consumerParent : consumer->inputs()) {
if (consumerParent.first &&
(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
if (verbose) printf(" not runnable: C%zu + R%zu > P%zu\n",
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
&& */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
getNbAvailableData(consumer, inputIdx)) {
if (verbose) printf(" not runnable: C%zu + R%zu > P%zu for input #%d\n",
consumer->getOperator()->getNbConsumedData(inputIdx),
consumer->getOperator()->getNbRequiredData(inputIdx),
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
getNbAvailableData(consumer, inputIdx), inputIdx);
// not enough data to run
isRunnable = false;
break;
}
++inputIdx;
}
if (isRunnable) {
......@@ -178,11 +182,16 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
for (const auto& runnable : runnableConsumers) {
if (verbose) printf("Runnable: %s\n", namePtrTable[runnable].c_str());
runnable->getOperator()->updateConsummerProducer();
mStaticSchedule.push_back(runnable);
mStaticSchedule.back().push_back(runnable);
}
if (runnableConsumers.empty()) {
frozenConsumers.push_back(consumers);
if (std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end()) {
frozenConsumers.push_back(consumers);
}
else {
break;
}
}
else {
frozenConsumers.clear();
......@@ -209,23 +218,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
printf("%zu", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
printf("\n");
}
bool isStillConsumer = false;
IOIndex_t inputIdx = 0; // FIXME: handle this correctly
// should we check input or dataInput ?
for (const auto& consumerParent : consumer->inputs()) {
if (consumerParent.first &&
consumer->getOperator()->getNbConsumedData(inputIdx) <
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
if (verbose) printf(" still consumer: C%zu < P%zu\n",
bool isStillConsumer = false;
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (consumer->getOperator()->getNbConsumedData(inputIdx) <
getNbAvailableData(consumer, inputIdx)) {
if (verbose) printf(" still consumer: C%zu < P%zu for input #%d\n",
consumer->getOperator()->getNbConsumedData(inputIdx),
consumerParent.first->getOperator()->getNbProducedData(consumerParent.second));
getNbAvailableData(consumer, inputIdx), inputIdx);
// there is still data to consume
isStillConsumer = true;
break;
}
++inputIdx;
}
for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
......@@ -234,21 +239,34 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// make sure consumer is also a producer
producers.insert(consumer);
const auto& childs = consumer->getChildren();
consumers.insert(childs.begin(), childs.end());
const auto& newConsumers = getConsumers({consumer});
consumers.insert(newConsumers.cbegin(), newConsumers.cend());
break;
}
}
if (!isStillConsumer) {
if (verbose) printf(" no more consumer\n");
// consumer is no longer a consumer, only a producer
if (runnableConsumers.find(consumer) != runnableConsumers.end()) {
// If consumer was run, remove it from the consumers list for
// now
consumers.erase(consumer);
if (isStillConsumer) {
// If there is still data to consume, the consumer will be
// run AFTER the other remaining consumers
// (= non-greedy consumers)
stillConsumers.insert(consumer);
}
}
}
// If there is no more consumers, swap with possible "still consumers"
// This ensures that the "non-greedy" consumer behavior
if (consumers.empty()) {
consumers.swap(stillConsumers);
stillConsumers.clear();
}
if (verbose) printf("********************\n");
} while (!consumers.empty() && (frozenConsumers.empty() || std::find(frozenConsumers.begin(), frozenConsumers.end(), consumers) == frozenConsumers.end()));
} while (!consumers.empty());
if (verbose) {
if (!consumers.empty()) {
......@@ -270,14 +288,11 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
this->generateScheduling(verbose);
}
// Clear previous scheduling results
mScheduling.clear();
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
int cpt = 0;
for (const auto& runnable : mStaticSchedule) {
size_t cpt = 0;
for (const auto& runnable : mStaticSchedule.at(mStaticScheduleStep)) {
if (verbose)
printf("run: %s\n",
namePtrTable[runnable].c_str());
......@@ -292,6 +307,8 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
}
if (!verbose) drawProgressBar(1.0, 50, " ");
printf("\n");
++mStaticScheduleStep;
}
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
......@@ -321,12 +338,58 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
for (const auto& producer : producers) {
const auto& childs = producer->getChildren();
consumers.insert(childs.begin(), childs.end());
for (const auto& child : childs) {
// Do not schedule childs outside current graph!
if (mGraphView->inView(child)) {
consumers.insert(child);
}
}
}
return consumers;
}
Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
const auto parent = node->inputs()[inputIdx];
if (parent.first) {
// Parent is connected, everything if fine!
return parent.first->getOperator()->getNbProducedData(parent.second);
}
else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) {
// We are inside an upper operator (for instance a MetaOperator)
// We need to connect the "local" producer-consumer model to the upper
// one, by mapping local node inputs to the upper node inputs.
IOIndex_t nodeInputIdx = 0;
for (const auto& input : mGraphView->getOrderedInputs()) {
if (input.first == node) {
// Current node is an input
const auto upperInput = upperNode->inputs()[nodeInputIdx];
if (upperInput.first) {
return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
}
}
++nodeInputIdx;
}
}
// Otherwise, two cases:
if (node->getOperator()->getRawInput(inputIdx)) {
// Input is not connected but a valid tensor exists
// => This means data was fed manually to the input, without a Producer
// In this case, we assume a single-use data (unlike a Producer, which
// keep producing the data each time it is needed).
fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
}
else {
// Input is not connected, this is an error
AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
}
return 0;
}
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const
{
......@@ -338,6 +401,11 @@ Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::
(node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
parent.first->getOperator()->getNbProducedData(parent.second))
{
if (!mGraphView->inView(parent.first)) {
// Do not schedule prior outside the current graph!
return PriorProducersConsumers();
}
if (parent.first->type() == Producer_Op::Type) {
prior.requiredProducers.insert(parent.first);
prior.priorConsumers.insert(node);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment