/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/scheduler/Scheduler.hpp"
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <fmt/ranges.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
putchar('[');
int pos = static_cast<int>(barWidth * progress);
for (int i = 0; i < barWidth; ++i) {
if (i <= pos)
putchar('#');
else
putchar(' ');
}
fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo);
void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// TODO: For loop on the list of nodes to run
// run sequentially every runnable consumer once
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
// 1) Setup initial consumers list:
// It is the list of input nodes
std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
// Plus the list of nodes inside the graph connected to an inner producer
std::set<std::shared_ptr<Node>> producers;
for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
const auto producersConsumers = getConsumers(producers);
consumers.insert(producersConsumers.begin(), producersConsumers.end());
Olivier BICHLER
committed
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Still consumers are consumers that were run but can still consume data.
// They must be run AFTER the remaining consumers to ensure a non-greedy
// producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers;
mStaticSchedule.push_back(std::vector<std::shared_ptr<Node>>());
// 2) From the current consumers list, check if any prior consumer node
// is needed. A prior will generally be required for any node consuming
// parameters (weights and bias) that is not an input node.
// If for a given node, only parent producers (at any depth) are needed
// to satisfy its required data, it becomes a prior.
// If the prior node is a producer, it is added to the list of required
// producers.
// If the prior node is of another type, it replaces the initial consumer
// in the new priorConsumers list. The initial consumer will become
// again a consumer later, by construction.
if (verbose) fmt::print("List of consumers with their priors:\n");
std::set<std::shared_ptr<Node>> priorConsumers;

Grégoire Kubler
committed
mPriorCache.clear();
for (const auto& consumer : consumers) {
if (verbose) {
fmt::print("\t- consumer: ");
fmt::print(fg(fmt::color::orange), namePtrTable[consumer]);
fmt::print("\n");
const auto& prior = getPriorProducersConsumers(consumer);
if (prior.isPrior) {
if (verbose) {
Olivier BICHLER
committed
std::vector<std::string> requiredProducersName;
std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(),
std::back_inserter(requiredProducersName),
Olivier BICHLER
committed
fmt::print("\t\trequired producers: {}\n", requiredProducersName);
std::vector<std::string> priorConsumersName;
std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(),
std::back_inserter(priorConsumersName),
Olivier BICHLER
committed
fmt::print("\t\tprior consumers: {}\n", priorConsumersName);
requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend());
}
else {
priorConsumers.insert(consumer);
// 3) Prior consumers replace the initial consumers list.
// By construction, initial consumers will necessarily become consumers
// again later.
consumers.swap(priorConsumers);
// 4) Make producers generate the required data.
// Producers are special nodes that generate data on demand.
for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer();
mStaticSchedule.back().push_back(requiredProducer);
// 5) Find runnable consumers.
// A consumer is runnable if the required data is available for all of
// its inputs. At this point, not all consumers are necessarily
// runnable because some may depend on the execution of others (when
// there are multiple successive priors for example).
for (const auto& consumer : consumers) {
if (verbose) {
fmt::print("\t- consumer: ");
fmt::print(fg(fmt::color::orange), namePtrTable[consumer]);
fmt::print("\n\t\tC/R:\t");
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
fmt::print("\n");
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
&& */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
getNbAvailableData(consumer, inputIdx)) {
if (verbose) fmt::print(" not runnable: C{} + R{} > P{} for input #{}\n",
consumer->getOperator()->getNbConsumedData(inputIdx),
consumer->getOperator()->getNbRequiredData(inputIdx),
getNbAvailableData(consumer, inputIdx), inputIdx);
// not enough data to run
isRunnable = false;
break;
}
}
if (isRunnable) {
runnableConsumers.insert(consumer);
}
}
// 5) If no consumer is runnable, it is a stop condition!
if (runnableConsumers.empty()) {
// No consumer is runnable: some required data is missing for all of
// them. There are two possibilities:
// - At least one required data source is exhausted, which may be
// an expected stop condition.
// - There is a deadlock between consumers, if one of them is waiting
// for data from another and vice versa.
break;
}
// 6) Push runnable consumers in the list of nodes to run and update the
// consumer producer system.
// At this point, simultaneously runnable consumers have no data
// dependency and could be run in parallel!
if (verbose) fmt::print("Runnable: {}\n", namePtrTable[runnable]);
runnable->getOperator()->updateConsummerProducer();
mStaticSchedule.back().push_back(runnable);
// 7) Update consumers list
if (verbose) fmt::print("Updating producer and consumer lists...\n");
for (const auto& consumer : runnableConsumers) {
fmt::print("\t- consumer: {}\n\t\tC/R:\t",
namePtrTable[consumer]);
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
fmt::print("\n");
// 7.1) If the current consumer has still data to consume, it will
// be put back in the consumers list once the remaining consumers
// have been exhausted.
bool isStillConsumer = false;
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (consumer->getOperator()->getNbConsumedData(inputIdx) <
getNbAvailableData(consumer, inputIdx)) {
if (verbose) fmt::print(" still consumer: C{} < P{} for input #{}\n",
getNbAvailableData(consumer, inputIdx), inputIdx);
// there is still data to consume
isStillConsumer = true;
break;
}
}
// 7.2) If the current consumer becomes a producer for other nodes,
// its children become consumers.
bool isProducer = false;
for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
for (const auto& child : consumer->getChildren(outId)) {
if (child) {
IOIndex_t inputIdx = 0;
for (const auto& childParent : child->getParents()) {
if (childParent == consumer) {
if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
isProducer = true;
}
}
++inputIdx;
}
}
}
/*
if (consumer->getOperator()->getNbProducedData(outId) > 0) {
// make sure consumer is also a producer
producers.insert(consumer);
const auto& newConsumers = getConsumers({consumer});
consumers.insert(newConsumers.cbegin(), newConsumers.cend());
consumers.erase(consumer);
if (isProducer) {
// make sure consumer is also a producer
producers.insert(consumer);
const auto& newConsumers = getConsumers({consumer});
consumers.insert(newConsumers.cbegin(), newConsumers.cend());
}
if (isStillConsumer) {
// If there is still data to consume, the consumer will be
// run AFTER the other remaining consumers
// (= non-greedy consumers)
stillConsumers.insert(consumer);
// 8) If there is no more consumers, swap with possible "still consumers"
// This ensures that the "non-greedy" consumer behavior
if (consumers.empty()) {
consumers.swap(stillConsumers);
stillConsumers.clear();
}
} while (!consumers.empty());
fmt::print("/!\\ Remaining consumers: possible dead-lock\n");
fmt::print("********************\n");
/**
 * @brief Reset the scheduler to an unscheduled state.
 *
 * Calls resetConsummerProducer() on the operator of every node of the graph
 * view, then discards the static schedule, rewinds the static schedule step
 * to 0 and clears the recorded scheduling elements.
 */
void Aidge::SequentialScheduler::resetScheduling() {
    // Bind by const reference: iterating by value would copy a shared_ptr
    // (ref-count increment/decrement) for every node of the graph.
    for (const auto& node : mGraphView->getNodes()) {
        node->getOperator()->resetConsummerProducer();
    }

    mStaticSchedule.clear();
    mStaticScheduleStep = 0;
    mScheduling.clear();
}
/**
 * @brief Build a memory mapping for the outputs of the scheduled nodes.
 *
 * This version is a simplified version without special handling of concatenation.
 *
 * Nodes are visited in the order of the static schedule. For each output of
 * each node, a memory plane is either newly allocated or, when
 * @p wrapAroundBuffer is set and a suitable parent plane was found, obtained
 * by re-allocating the plane of one of the node's parents. Once all outputs of
 * a node are handled, its dependencies are released and tick() is called on
 * the memory manager.
 *
 * @param incProducers If false, nodes of type Producer_Op get no memory
 *        plane: only releaseDependencies() is called for them.
 * @param wrapAroundBuffer If true, allow an output to re-use the memory plane
 *        of a parent output whose only child is the current node. The search
 *        for such a plane is performed in any case; this flag only decides
 *        whether its result is used.
 * @return The MemoryManager filled during the traversal.
 */
Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
    MemoryManager memManager;

    for (const auto& schedule : mStaticSchedule) {
        for (const auto& node : schedule) {
            if (!incProducers && node->type() == Producer_Op::Type) {
                memManager.releaseDependencies(node);
                continue;
            }

            const auto childs = node->getChildren();
            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());

            // One entry per output of the current node: the parent plane
            // selected for re-use, or nullptr if none was selected.
            std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;

            // Allocate a memory plane for each node's output
            for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
                const size_t requiredSize = op->getRequiredMemory(outputIdx, {});

                // By default, specifies a fully monolithic memory block
                // NOTE(review): `size` is only used to derive `stride` below;
                // the allocation calls are given `requiredSize`.
                size_t size = requiredSize;
                size_t stride = 0;
                size_t length = 1;
                size_t count = 1;

                if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) {
                    // If it is possible, assume a NCHW layout
                    size = op->getOutput(outputIdx)->dims().end()[-3];
                    stride = size;
                    length = op->getOutput(outputIdx)->dims().end()[-1];
                    count = op->getOutput(outputIdx)->dims().end()[-2];
                }

                // Check if wrap around buffer is possible for this node
                // (re-using previous node outputs memory for this node outputs).
                // => only if this node is the only child of its parent(s)
                size_t wrapAroundSize = 0;
                size_t wrapAroundExtra = 0;
                wrapAroundMemPlane.push_back(nullptr);

                // Select the best parent among all allocable nodes for
                // reallocation, which is the one with most memory (in order
                // to minimize the reallocation size).
                IOIndex_t inputIdx = 0;
                for (const auto& parent : node->dataInputs()) {
                    // The parent's plane is looked up among the last
                    // nbOutputs() planes registered for it, at position
                    // parent.second within that group.
                    if (parent.first && parent.first->getChildren(parent.second).size() == 1
                        // there might be no existing plane if the parent was
                        // not yet scheduled (because it may be a recurrent connection)
                        && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
                        // memSpace should not be already released
                        && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
                    {
                        const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx));
                        const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];

                        if (isWrappable || !memManager.isWrapAround(
                                    memPlane.memSpace,
                                    memPlane.getFinalOffset()
                                        - memPlane.memSpace->offset,
                                    requiredSize))
                        {
                            // Keep this plane only if it offers more usable
                            // room than the current best candidate and was
                            // not already selected for another output.
                            if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx)
                                && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
                            {
                                wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx);
                                if (requiredSize > wrapAroundSize) {
                                    wrapAroundExtra = requiredSize - wrapAroundSize;
                                }
                                wrapAroundMemPlane[outputIdx] = &memPlane;
                            }

                            // No extra memory needed: stop searching.
                            if (wrapAroundExtra == 0) {
                                break;
                            }
                        }
                    }
                    ++inputIdx;
                }

                // MemoryPlane to (re)use
                const MemoryManager::MemoryPlane& memPlane
                    = (wrapAroundBuffer && wrapAroundSize > 0)
                        ? (*wrapAroundMemPlane[outputIdx]) :
                            memManager.allocate(requiredSize, childs, stride, length, count);

                if (wrapAroundBuffer && wrapAroundSize > 0) {
                    memManager.reallocate(memPlane,
                        node, 0,
                        requiredSize, true, wrapAroundExtra, childs, stride, length, count);
                }
                else {
                    memManager.reallocate(memPlane.memSpace,
                        node, memPlane.offset,
                        requiredSize, false, 0, childs, stride, length, count);
                }
            }

            memManager.releaseDependencies(node);
            memManager.tick();
        }
    }

    return memManager;
}
Thibault Allenet
committed
void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){
// This version of connect inputs only connects tensor inputs in input data producers.
auto inputNodes = mGraphView->getOrderedInputs();
// Assert that the number of input data producers corresponds to the number of data input
assert(data.size() == inputNodes.size() && "Scheduler connectInput error - Inconsistent number of graph inputs and inputs passed to the graph");
Thibault Allenet
committed
for (std::size_t i = 0; i < data.size(); ++i){
// TODO : maybe shallow copy instead of deepcopy
inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
}
}
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
Thibault Allenet
committed
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive calls to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
size_t cpt = 0;
for (const auto& runnable : mStaticSchedule.at(mStaticScheduleStep)) {
if (verbose)
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad, bool verbose) {
// create and set Grad values
if (instanciateGrad) { compile_gradient(mGraphView); }
const auto& ordered_outputs = mGraphView->getOrderedOutputs();
AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
right number of data objects to run the backward function. \
{} outputs detected for the current GraphView when {} were \
provided.", ordered_outputs.size(), data.size());
for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size.");
*t_grad = data[i]->clone();
}
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive calls to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
// map of node <-> info to display with verbose
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Clear previous scheduling results
mScheduling.clear();
std::size_t cpt = 0;
// run scheduled operators in reverse order
const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep);
for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) {
fmt::print("run: {}\n", namePtrTable.at(*runnable));
else
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
(std::string("running ") + namePtrTable.at(*runnable)));
const auto tStart = std::chrono::high_resolution_clock::now();
(*runnable)->backward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(*runnable, tStart, tEnd));
cpt++;
}
if (!verbose) drawProgressBar(1.0, 50, " ");
fmt::print("\n");
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Could not create scheduling diagram log file: {}", fileName + ".mmd");
}
fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q µs\n\n");
Olivier BICHLER
committed
const std::map<std::shared_ptr<Node>, std::string> namePtrTable
= mGraphView->getRankedNodesName("{0} ({1}#{3})");
const auto globalStart = mScheduling[0].start;
for (const auto& element : mScheduling) {
auto name = namePtrTable.at(element.node);
// Mermaid does not allow : character in task title
std::replace(name.begin(), name.end(), ':', '_');
std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
}
}
}
std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
const std::set<std::shared_ptr<Node>>& producers) const {
std::set<std::shared_ptr<Node>> consumers;
for (const auto& producer : producers) {
const auto& childs = producer->getChildren();
for (const auto& child : childs) {
// Do not schedule childs outside current graph!
if (mGraphView->inView(child)) {
consumers.insert(child);
}
}

Cyril Moineau
committed
}
Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
const auto parent = node->inputs()[inputIdx];
if (parent.first) {
// Parent is connected, everything is fine!
return parent.first->getOperator()->getNbProducedData(parent.second);
}
else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) {
// We are inside an upper operator (for instance a MetaOperator)
// We need to connect the "local" producer-consumer model to the upper
// one, by mapping local node inputs to the upper node inputs.
IOIndex_t nodeInputIdx = 0;
for (const auto& input : mGraphView->getOrderedInputs()) {
if (input.first == node) {
// Current node is an input
const auto upperInput = upperNode->inputs()[nodeInputIdx];
if (upperInput.first) {
return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
}
}
// Otherwise, two cases:
if (node->getOperator()->getRawInput(inputIdx)) {
// Input is not connected but a valid tensor exists
// => This means data was fed manually to the input, without a Producer
// In this case, we assume a single-use data (unlike a Producer, which
// keeps producing the data each time it is needed).
fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
}
else {
// Input is not connected, this is an error
AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
}
return 0;
}
/**
 * @brief Recursively collect the producers and prior consumers required
 *        before @p node can run.
 *
 * For every connected input of @p node whose consumed + required data exceeds
 * what the parent has produced so far:
 *  - a parent outside the graph view aborts the search;
 *  - a Producer_Op parent is recorded as a required producer and @p node
 *    itself as a prior consumer;
 *  - a Memorize_Op parent aborts the search (this breaks cycles);
 *  - any other parent is explored recursively and its results are merged,
 *    unless it is not itself a prior, which aborts the search.
 *
 * An aborted search returns a default-constructed PriorProducersConsumers
 * (presumably with isPrior == false - confirm in the header) and is not
 * cached. A successful search has isPrior set to true, contains at least
 * @p node in priorConsumers, and is stored in mPriorCache.
 *
 * @param node Node whose priors are requested.
 * @return The cached or newly computed prior description for @p node.
 */
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
    const std::shared_ptr<Node>& node) const
{
    // Re-use a previously computed result for this node, if any.
    const auto priorCache = mPriorCache.find(node);
    if (priorCache != mPriorCache.end()) {
        return priorCache->second;
    }

    PriorProducersConsumers prior;

    IOIndex_t inputIdx = 0;
    for (const auto& parent : node->inputs()) {
        // Only inputs that are connected and still lack data are of interest.
        if (parent.first &&
            (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
                    parent.first->getOperator()->getNbProducedData(parent.second))
        {
            if (!mGraphView->inView(parent.first)) {
                // Do not schedule prior outside the current graph!
                return PriorProducersConsumers();
            }

            if (parent.first->type() == Producer_Op::Type) {
                prior.requiredProducers.insert(parent.first);
                prior.priorConsumers.insert(node);
            }
            else if (parent.first->type() == Memorize_Op::Type) {
                // Break cycles
                return PriorProducersConsumers();
            }
            else {
                const auto& parentPrior = getPriorProducersConsumers(parent.first);

                if (!parentPrior.isPrior) {
                    return PriorProducersConsumers();
                }
                else {
                    prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
                    prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
                }
            }
        }
        ++inputIdx;
    }

    prior.isPrior = true;
    // No input required anything: the node is its own prior consumer.
    if (prior.priorConsumers.empty()) {
        prior.priorConsumers.insert(node);
    }

    // Like the original insert(), emplace() keeps an existing entry untouched.
    mPriorCache.emplace(node, prior);
    return prior;
}