/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/scheduler/Scheduler.hpp"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
#include <fmt/color.h>
#include <fmt/ranges.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
putchar('[');
int pos = static_cast<int>(barWidth * progress);
for (int i = 0; i < barWidth; ++i) {
if (i <= pos)
putchar('#');
else
putchar(' ');
}
fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo);
void Aidge::SequentialScheduler::generateScheduling() {
auto schedule = generateBaseScheduling();
generateEarlyLateScheduling(schedule);
mStaticSchedule.push_back(schedule);
}
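// Typical usage (sketch, names hypothetical): each call to generateScheduling()
// appends one static schedule step, consumed later by forward():
//   SequentialScheduler scheduler(graphView);
//   scheduler.generateScheduling(); // pre-compute one execution step
//   scheduler.forward();            // run it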
std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>> Aidge::SequentialScheduler::generateBaseScheduling() const {
// TODO: For loop on the list of nodes to run
// run sequentially every runnable consumer once
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
// 1) Setup initial consumers list:
// It is the list of input nodes
std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
// Plus the list of nodes inside the graph connected to an inner producer
std::set<std::shared_ptr<Node>> producers;
for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
if (nodePtr->type() == Producer_Op::Type) {
producers.insert(nodePtr);
}
}
const auto producersConsumers = getConsumers(producers);
consumers.insert(producersConsumers.begin(), producersConsumers.end());
const std::map<std::shared_ptr<Node>, std::string> namePtrTable
= mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Still consumers are consumers that were already run but can still
// consume data. They must be run AFTER the remaining consumers to ensure
// a non-greedy producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers;
std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
do {
// 2) From the current consumers list, check if any prior consumer node
// is needed. A prior will generally be required for any node consuming
// parameters (weights and bias) that is not an input node.
// If for a given node, only parent producers (at any depth) are needed
// to satisfy its required data, it becomes a prior.
// If the prior node is a producer, it is added to the list of required
// producers.
// If the prior node is of another type, it replaces the initial consumer
// in the new priorConsumers list. The initial consumer will become
// again a consumer later, by construction.
Log::debug("List of consumers with their priors:");
std::set<std::shared_ptr<Node>> priorConsumers;
Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));
const auto& prior = getPriorProducersConsumers(consumer);
if (prior.isPrior) {
std::vector<std::string> requiredProducersName;
std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(),
std::back_inserter(requiredProducersName),
[&namePtrTable](auto val){ return namePtrTable.at(val); });
Log::debug("\t\trequired producers: {}", requiredProducersName);
std::vector<std::string> priorConsumersName;
std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(),
std::back_inserter(priorConsumersName),
[&namePtrTable](auto val){ return namePtrTable.at(val); });
Log::debug("\t\tprior consumers: {}", priorConsumersName);
requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend());
}
else {
priorConsumers.insert(consumer);
}
}
// 3) Prior consumers replace the initial consumers list.
// By construction, initial consumers will necessarily become consumers
// again later.
consumers.swap(priorConsumers);
// 4) Make producers generate the required data.
// Producers are special nodes that generate data on demand.
for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
}
// 5) Find runnable consumers.
// A consumer is runnable if the required data is available for all of
// its inputs. At this point, not all consumers are necessarily
// runnable because some may depend on the execution of others (when
// there are multiple successive priors, for example).
Log::debug("Updated list of consumers:");
Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));
std::string crLog = "\t\tC/R:\t";
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
consumer->getOperator()->getNbRequiredData(inId));
}
crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
Log::debug("{}", crLog);
std::string pLog = "\t\tP:\t";
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
}
pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
Log::debug("{}", pLog);
bool isRunnable = true;
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
&& */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
getNbAvailableData(consumer, inputIdx)) {
Log::debug(" not runnable: C{} + R{} > P{} for input #{}",
consumer->getOperator()->getNbConsumedData(inputIdx),
consumer->getOperator()->getNbRequiredData(inputIdx),
getNbAvailableData(consumer, inputIdx), inputIdx);
// not enough data to run
isRunnable = false;
break;
}
}
if (isRunnable) {
runnableConsumers.insert(consumer);
}
}
// 5) If no consumer is runnable, it is a stop condition!
if (runnableConsumers.empty()) {
// No consumer is runnable: some required data is missing for all of
// them. There are two possibilities:
// - At least one required data source is exhausted, which may be
//   an expected stop condition.
// - There is a deadlock between consumers, each one waiting for data
//   from another.
break;
}
// 6) Push runnable consumers in the list of nodes to run and update the
// consumer producer system.
// At this point, simultaneously runnable consumers have no data
// dependency and could be run in parallel!
Log::debug("Runnable: {}", namePtrTable.at(runnable));
runnable->getOperator()->updateConsummerProducer();
schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
}
// 7) Update consumers list
Log::debug("Updating producer and consumer lists...");
for (const auto& consumer : runnableConsumers) {
Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));
std::string crLog = "\t\tC/R:\t";
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
consumer->getOperator()->getNbRequiredData(inId));
}
crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
Log::debug("{}", crLog);
std::string pLog = "\t\tP:\t";
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
Log::debug("{}", pLog);
// 7.1) If the current consumer has still data to consume, it will
// be put back in the consumers list once the remaining consumers
// have been exhausted.
bool isStillConsumer = false;
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (consumer->getOperator()->getNbConsumedData(inputIdx) <
getNbAvailableData(consumer, inputIdx)) {
Log::debug(" still consumer: C{} < P{} for input #{}",
getNbAvailableData(consumer, inputIdx), inputIdx);
// there is still data to consume
isStillConsumer = true;
break;
}
}
// 7.2) If the current consumer becomes a producer for other nodes,
// its children become consumers.
bool isProducer = false;
for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
for (const auto& child : consumer->getChildren(outId)) {
if (child) {
IOIndex_t inputIdx = 0;
for (const auto& childParent : child->getParents()) {
if (childParent == consumer) {
if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
isProducer = true;
}
}
++inputIdx;
}
}
}
}
/*
if (consumer->getOperator()->getNbProducedData(outId) > 0) {
// make sure consumer is also a producer
producers.insert(consumer);
const auto& newConsumers = getConsumers({consumer});
consumers.insert(newConsumers.cbegin(), newConsumers.cend());
}
*/
consumers.erase(consumer);
if (isProducer) {
// make sure consumer is also a producer
producers.insert(consumer);
const auto& newConsumers = getConsumers({consumer});
consumers.insert(newConsumers.cbegin(), newConsumers.cend());
}
if (isStillConsumer) {
// If there is still data to consume, the consumer will be
// run AFTER the other remaining consumers
// (= non-greedy consumers)
stillConsumers.insert(consumer);
}
}
// 8) If there are no more consumers, swap with possible "still consumers".
// This ensures the "non-greedy" consumer behavior.
if (consumers.empty()) {
consumers.swap(stillConsumers);
stillConsumers.clear();
}
} while (!consumers.empty());
if (!consumers.empty()) {
Log::warn("Remaining consumers: possible dead-lock");
void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
size_t latest = 0;
// Calculate early (logical) start
for (size_t elt = 0; elt < schedule.size(); ++elt) {
const auto node = schedule[elt]->node;
const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
[node](const auto& v) { return (v->node == node); });
size_t early = 0;
if (itNode != schedule.rend()) {
// Node can be run the earliest just after it was run the last time!
early = (*itNode)->early + 1;
(*itNode)->earlierThan.push_back(schedule[elt]);
}
// Node can be run the earliest just after its latest parent was run
for (const auto& parent : node->getParents()) {
// Find parent node latest scheduled position
const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
[parent](const auto& v) { return (v->node == parent); });
if (it != schedule.rend()) {
const size_t step = std::distance(schedule.begin(), it.base()) - 1;
early = std::max(early, schedule[step]->early + 1);
schedule[step]->earlierThan.push_back(schedule[elt]);
}
}
latest = std::max(latest, early);
schedule[elt]->early = early;
}
// Calculate late (logical) start
for (size_t elt = schedule.size(); elt-- != 0; ) {
const auto node = schedule[elt]->node;
const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[node](const auto& v) { return (v->node == node); });
size_t late = latest;
// Node can be run the latest just before it is run the next time!
if (itNode != schedule.end()) {
late = (*itNode)->late - 1;
}
// Node can be run the latest just before its earliest child is run
for (const auto& child : node->getChildren()) {
// Find child node earliest scheduled position
const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[child](const auto& v) { return (v->node == child); });
if (it != schedule.end()) {
const size_t step = std::distance(schedule.begin(), it);
late = std::min(late, schedule[step]->late - 1);
}
}
schedule[elt]->late = late;
}
}
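// Worked example (sketch): for a diamond graph A -> {B, C} -> D statically
// scheduled as [A, B, C, D], the two passes above give:
//   early: A=0, B=1, C=1, D=2   (each node right after its latest parent)
//   late:  A=0, B=1, C=1, D=2   (each node right before its earliest child)
// B and C share the same logical step (early == late == 1) with no mutual
// dependency, so the ParallelScheduler below may run them concurrently.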
void Aidge::SequentialScheduler::resetScheduling() {
for (auto node : mGraphView->getNodes()) {
node->getOperator()->resetConsummerProducer();
}
mStaticSchedule.clear();
mStaticScheduleStep = 0;
mScheduling.clear();
}
/**
* This simplified version has no special handling of concatenation.
*/
Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
MemoryManager memManager;
for (size_t step = 0; step < mStaticSchedule.size(); ++step) {
for (const auto& node : getStaticScheduling(step)) {
if (!incProducers && node->type() == Producer_Op::Type) {
memManager.releaseDependencies(node);
continue;
}
const auto childs = node->getChildren();
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
// Allocate a memory plane for each node's output
for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
const size_t requiredSize = op->getRequiredMemory(outputIdx, {});
// By default, specifies a fully monolithic memory block
size_t size = requiredSize;
size_t stride = 0;
size_t length = 1;
size_t count = 1;
if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) {
// If it is possible, assume a NCHW layout
size = op->getOutput(outputIdx)->dims().end()[-3];
stride = size;
length = op->getOutput(outputIdx)->dims().end()[-1];
count = op->getOutput(outputIdx)->dims().end()[-2];
}
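// e.g. (sketch) for an output with dims [N=1, C=16, H=32, W=32] under the
// NCHW assumption above: size = stride = 16 (C), count = 32 (H), length = 32 (W).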
// Check if wrap around buffer is possible for this node
// (re-using previous node outputs memory for this node outputs).
// => only if this node is the only child of its parent(s)
size_t wrapAroundSize = 0;
size_t wrapAroundExtra = 0;
wrapAroundMemPlane.push_back(nullptr);
// Select the best parent among all allocable nodes for
// reallocation, which is the one with the most memory (in order
// to minimize the reallocation size).
IOIndex_t inputIdx = 0;
for (const auto& parent : node->dataInputs()) {
if (parent.first && parent.first->getChildren(parent.second).size() == 1
// there might be no existing plane if the parent was
// not yet scheduled (because it may be a recurrent connection)
&& memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
// memSpace should not be already released
&& memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
{
const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx));
const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];
if (isWrappable || !memManager.isWrapAround(
memPlane.memSpace,
memPlane.getFinalOffset()
- memPlane.memSpace->offset,
requiredSize))
{
if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx)
&& std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
{
wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx);
if (requiredSize > wrapAroundSize) {
wrapAroundExtra = requiredSize - wrapAroundSize;
}
wrapAroundMemPlane[outputIdx] = &memPlane;
}
if (wrapAroundExtra == 0) {
break;
}
}
}
++inputIdx;
}
// MemoryPlane to (re)use
const MemoryManager::MemoryPlane& memPlane
= (wrapAroundBuffer && wrapAroundSize > 0)
? (*wrapAroundMemPlane[outputIdx]) :
memManager.allocate(requiredSize, childs, stride, length, count);
if (wrapAroundBuffer && wrapAroundSize > 0) {
memManager.reallocate(memPlane,
node, 0,
requiredSize, true, wrapAroundExtra, childs, stride, length, count);
}
else {
memManager.reallocate(memPlane.memSpace,
node, memPlane.offset,
requiredSize, false, 0, childs, stride, length, count);
}
}
memManager.releaseDependencies(node);
memManager.tick();
}
}
return memManager;
}
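// Usage sketch: compute a static memory layout once the schedule is known.
// With wrapAroundBuffer == true, a node that is the only child of its parent
// may reuse (part of) the parent's output memory plane, trading a possible
// small "wrapAroundExtra" allocation for a lower peak memory footprint:
//   auto memManager = scheduler.generateMemory(/*incProducers=*/false,
//                                              /*wrapAroundBuffer=*/true);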
void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){
// This version of connectInputs only connects tensor inputs to the input data producers of the graph.
auto inputNodes = mGraphView->getOrderedInputs();
// Assert that the number of input tensors matches the number of graph data inputs
assert(data.size() == inputNodes.size() && "Scheduler connectInputs error - inconsistent number of graph inputs and inputs passed to the graph");
for (std::size_t i = 0; i < data.size(); ++i){
// TODO : maybe shallow copy instead of deepcopy
inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
}
}
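// Example (sketch, tensor names hypothetical): input tensors must be passed in
// the same order as mGraphView->getOrderedInputs(), e.g. through forward():
//   scheduler.forward(true, false, {inputTensorA, inputTensorB});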
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Connect all data inputs of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive calls to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
size_t cpt = 0;
for (const auto& runnable : getStaticScheduling(mStaticScheduleStep)) {
if (verbose)
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
(std::string("running ") + namePtrTable.at(runnable)));
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
++cpt;
}
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
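// Note on stepping (sketch): each forward() call executes the current static
// schedule step and advances mStaticScheduleStep, wrapping back to 0 past the
// last step. A graph scheduled in N steps (N calls to generateScheduling())
// is thus fully executed by N successive forward() calls.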
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Could not create scheduling diagram log file: {}", fileName + ".mmd");
}
fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q µs\n\n");
const std::map<std::shared_ptr<Node>, std::string> namePtrTable
= mGraphView->getRankedNodesName("{0} ({1}#{3})");
const auto globalStart = mScheduling[0].start;
for (const auto& element : mScheduling) {
auto name = namePtrTable.at(element.node);
// Mermaid does not allow : character in task title
std::replace(name.begin(), name.end(), ':', '_');
fmt::print(fp.get(), "{} :{}, {}\n", name,
std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
}
}
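// Resulting file sketch (Mermaid gantt, one task per executed node, names and
// timings hypothetical):
//   gantt
//   dateFormat x
//   axisFormat %Q µs
//
//   Conv2D (Conv#0) :0, 125
//   ReLU (ReLU#0) :125, 130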
void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Could not create scheduling diagram log file: {}", fileName + ".mmd");
}
fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n");
const std::map<std::shared_ptr<Node>, std::string> namePtrTable
= mGraphView->getRankedNodesName("{0} ({1}#{3})");
for (const auto& schedule : mStaticSchedule) {
for (const auto& element : schedule) {
auto name = namePtrTable.at(element->node);
// Mermaid does not allow : character in task title
std::replace(name.begin(), name.end(), ':', '_');
fmt::print(fp.get(), "{} :{}, {}\n", name, element->early, element->late);
}
}
fmt::print(fp.get(), "\n");
}
std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const {
const auto& staticSchedule = mStaticSchedule.at(step);
std::vector<std::shared_ptr<Node>> schedule;
std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
return schedule;
}
std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
const std::set<std::shared_ptr<Node>>& producers) const {
std::set<std::shared_ptr<Node>> consumers;
for (const auto& producer : producers) {
const auto& childs = producer->getChildren();
for (const auto& child : childs) {
// Do not schedule children outside the current graph!
if (mGraphView->inView(child)) {
consumers.insert(child);
}
}
}
return consumers;
}
Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
const auto parent = node->inputs()[inputIdx];
if (parent.first) {
// Parent is connected, everything is fine!
return parent.first->getOperator()->getNbProducedData(parent.second);
}
else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) {
// We are inside an upper operator (for instance a MetaOperator)
// We need to connect the "local" producer-consumer model to the upper
// one, by mapping local node inputs to the upper node inputs.
IOIndex_t nodeInputIdx = 0;
for (const auto& input : mGraphView->getOrderedInputs()) {
if (input.first == node) {
// Current node is an input
const auto upperInput = upperNode->inputs()[nodeInputIdx];
if (upperInput.first) {
return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
}
}
++nodeInputIdx;
}
}
// Otherwise, two cases:
if (node->getOperator()->getRawInput(inputIdx)) {
// Input is not connected but a valid tensor exists
// => This means data was fed manually to the input, without a Producer
// In this case, we assume single-use data (unlike a Producer, which
// keeps producing the data each time it is needed).
fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
}
else {
// Input is not connected, this is an error
AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
}
return 0;
}
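// Summary of the cases above (sketch):
//   1. input connected to a parent        -> parent's getNbProducedData()
//   2. graph input inside an upper node   -> producer feeding the upper node
//   3. unconnected, tensor set manually   -> tensor size (single-use data)
//   4. unconnected, no tensor             -> error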
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
const std::shared_ptr<Node>& node) const
{
PriorProducersConsumers prior;
IOIndex_t inputIdx = 0;
for (const auto& parent : node->inputs()) {
if (parent.first &&
(node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
parent.first->getOperator()->getNbProducedData(parent.second))
{
if (!mGraphView->inView(parent.first)) {
// Do not schedule prior outside the current graph!
return PriorProducersConsumers();
}
if (parent.first->type() == Producer_Op::Type) {
prior.requiredProducers.insert(parent.first);
prior.priorConsumers.insert(node);
}
else if (parent.first->type() == Memorize_Op::Type) {
// Break cycles
return PriorProducersConsumers();
}
else {
const auto& parentPrior = getPriorProducersConsumers(parent.first);
if (!parentPrior.isPrior) {
return PriorProducersConsumers();
}
else {
prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
}
}
}
++inputIdx;
}
prior.isPrior = true;
if (prior.priorConsumers.empty()) {
prior.priorConsumers.insert(node);
}
return prior;
}
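// Illustration (hypothetical nodes): for a Conv whose weight and bias inputs
// are unsatisfied Producers and whose data input is already satisfied, this
// returns isPrior = true, requiredProducers = {weight, bias} and
// priorConsumers = {conv}. If any unsatisfied parent chain leaves the graph
// or goes through a Memorize node, the node is not a prior.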
void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Connect all data inputs of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive calls to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Sort the static scheduling: this order will be the preferred thread
// scheduling order for non-critical nodes
std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });
// The thread pool has N threads running to process nodes.
// Thread pooling avoids the overhead of thread creation and deletion for
// each node execution.
ThreadPool pool;
pool.start();
size_t latest = 0;
std::mutex schedulingMutex;
std::vector<int> required(staticSchedule.back()->late + 1, 0);
std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1);
std::fill(finished.begin(), finished.end(), 0);
while (!staticSchedule.empty()) {
Log::debug("Step {}", latest);
// Run all nodes that must be run at latest
for (size_t i = 0; i < staticSchedule.size(); ) {
auto runnable = staticSchedule[i];
if (runnable->late == latest) {
// Critical path
pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
++finished;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.erase(staticSchedule.begin() + i);
++required[latest];
Log::debug(" run critical {}", namePtrTable.at(runnable->node));
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
}
}
}
else if (runnable->early > latest + 1) {
// There cannot be any more nodes that must be run at latest.
// We compare against latest + 1 and not latest because early may have
// been updated to latest + 1 for some elements (above).
break;
}
else {
++i;
}
}
// If some threads are still available, run next early nodes
while (!pool.busy() && !staticSchedule.empty()) {
auto runnable = staticSchedule.front();
if (runnable->early <= latest) {
pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() {
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
++finished;
{
std::unique_lock<std::mutex> lock(schedulingMutex);
mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
}
});
staticSchedule.pop_front();
++required[runnable->late];
Log::debug(" run {}", namePtrTable.at(runnable->node));
for (auto elt : runnable->earlierThan) {
if (elt->early < latest + 1) {
Log::debug(" {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
elt->early = latest + 1;
}
}
}
else {
break;
}
}
// Wait for all nodes that must finish at latest to be finished
while (finished[latest] < required[latest]) {
std::this_thread::yield();
}
++latest;
}
pool.stop();
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
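// Execution model recap (sketch): nodes whose late time equals the current
// logical step are on the critical path and are always queued; idle threads
// are then filled with any node whose early time allows it to run. For the
// diamond example given with generateEarlyLateScheduling(), B and C are both
// queued at step 1 and run concurrently on two pool threads.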