Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/scheduler/Scheduler.hpp"

#include <chrono>
#include <cstdio>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>

#include <fmt/ranges.h>

#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
/**
 * @brief Draw a one-line textual progress bar on stdout.
 *
 * The output has the form `[###   ] NN% | info` and ends with a carriage
 * return instead of a newline, so that the next call overwrites the same
 * terminal line.
 *
 * @param progress Completion ratio. It is multiplied by @p barWidth to get
 *        the index of the last filled cell, and by 100 to get the percentage.
 * @param barWidth Number of cells drawn between the brackets.
 * @param additionalInfo Free text printed after the percentage.
 */
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
    putchar('[');
    const int pos = static_cast<int>(barWidth * progress);
    for (int i = 0; i < barWidth; ++i) {
        // Cells 0..pos (inclusive) are filled, the remaining ones are blank.
        putchar((i <= pos) ? '#' : ' ');
    }
    fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo);
    // The line ends with '\r' and not '\n': flush explicitly, otherwise a
    // buffered stdout may not display the bar until something else flushes it.
    fflush(stdout);
}
void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// TODO: For loop on the list of node to run
// run sequentially every runnable consumer once
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
// 1) Setup initial consumers list:
// It is the list of input nodes
std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
// Plus the list of nodes inside the graph connected to an inner producer
std::set<std::shared_ptr<Node>> producers;
for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
const auto producersConsumers = getConsumers(producers);
consumers.insert(producersConsumers.begin(), producersConsumers.end());
Olivier BICHLER
committed
std::map<std::shared_ptr<Node>, std::string> namePtrTable;
if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Still consumers are consumers that were run but can still consume data.
// They must be run AFTER the remaining consumer to ensure a non-greedy
// producers-consumers model!
std::set<std::shared_ptr<Node>> stillConsumers;
mStaticSchedule.push_back(std::vector<std::shared_ptr<Node>>());
// 2) From the current consumers list, check if any prior consumer node
// is needed. A prior will generally be required for any node consuming
// parameters (weights and bias) that is not an input node.
// If for a given node, only parent producers (at any depth) are needed
// to satisfy its required data, it becomes a prior.
// If the prior node is a producer, it is added to the list of required
// producers.
// If the prior node is of another type, it replaces the initial consumer
// in the new priorConsumers list. The initial consumer will become
// again a consumer later, by construction.
if (verbose) fmt::print("List of consumers with their priors:\n");
std::set<std::shared_ptr<Node>> priorConsumers;
for (const auto& consumer : consumers) {
if (verbose) {
fmt::print("\t- consumer: ");
fmt::print(fg(fmt::color::orange), namePtrTable[consumer]);
fmt::print("\n");
const auto& prior = getPriorProducersConsumers(consumer);
if (prior.isPrior) {
if (verbose) {
Olivier BICHLER
committed
std::vector<std::string> requiredProducersName;
std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(),
std::back_inserter(requiredProducersName),
Olivier BICHLER
committed
fmt::print("\t\trequired producers: {}\n", requiredProducersName);
std::vector<std::string> priorConsumersName;
std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(),
std::back_inserter(priorConsumersName),
Olivier BICHLER
committed
fmt::print("\t\tprior consumers: {}\n", priorConsumersName);
requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend());
}
else {
priorConsumers.insert(consumer);
// 3) Prior consumers replace the initial consumers list.
// By construction, initial consumers will necessarily become consumers
// again later.
consumers.swap(priorConsumers);
// 4) Make producers generate the required data.
// Producers are special nodes that generate data on demand.
for (const auto& requiredProducer : requiredProducers) {
requiredProducer->getOperator()->updateConsummerProducer();
mStaticSchedule.back().push_back(requiredProducer);
// 5) Find runnable consumers.
// A consumer is runnable if the required data is available for all of
// its inputs. At this point, not all consumers are necessarily
// runnable because some may depend on the execution of others (when
// there is multiple successive priors for example).
for (const auto& consumer : consumers) {
if (verbose) {
fmt::print("\t- consumer: ");
fmt::print(fg(fmt::color::orange), namePtrTable[consumer]);
fmt::print("\n\t\tC/R:\t");
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
fmt::print("\n");
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
&& */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
getNbAvailableData(consumer, inputIdx)) {
if (verbose) fmt::print(" not runnable: C{} + R{} > P{} for input #{}\n",
consumer->getOperator()->getNbConsumedData(inputIdx),
consumer->getOperator()->getNbRequiredData(inputIdx),
getNbAvailableData(consumer, inputIdx), inputIdx);
// not enough data to run
isRunnable = false;
break;
}
}
if (isRunnable) {
runnableConsumers.insert(consumer);
}
}
// 5) If no consumer is runnable, it is a stop condition!
if (runnableConsumers.empty()) {
// No consumer is runnable: some required data is missing for all of
// them. There are two possibilities:
// - At least one required data source is exhausted, which may be
// an expected stop condition.
// - There is a deadlock between consumers, if one of them is waiting
// for data from another and reciprocally.
break;
}
// 6) Push runnable consumers in the list of nodes to run and update the
// consumer producer system.
// At this point, simultaneously runnable consumers have no data
// dependency and could be run in parallel!
if (verbose) fmt::print("Runnable: {}\n", namePtrTable[runnable]);
runnable->getOperator()->updateConsummerProducer();
mStaticSchedule.back().push_back(runnable);
// 7) Update consumers list
if (verbose) fmt::print("Updating producer and consumer lists...\n");
for (const auto& consumer : runnableConsumers) {
fmt::print("\t- consumer: {}\n\t\tC/R:\t",
namePtrTable[consumer]);
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
fmt::print("\n");
// 7.1) If the current consumer has still data to consume, it will
// be put back in the consumers list once the remaining consumers
// have been exhausted.
bool isStillConsumer = false;
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (consumer->getOperator()->getNbConsumedData(inputIdx) <
getNbAvailableData(consumer, inputIdx)) {
if (verbose) fmt::print(" still consumer: C{} < P{} for input #{}\n",
getNbAvailableData(consumer, inputIdx), inputIdx);
// there is still data to consume
isStillConsumer = true;
break;
}
}
// 7.2) If the current consumer becomes a producer for other nodes,
// its children become consumers.
bool isProducer = false;
for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
for (const auto& child : consumer->getChildren(outId)) {
if (child) {
IOIndex_t inputIdx = 0;
for (const auto& childParent : child->getParents()) {
if (childParent == consumer) {
if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
isProducer = true;
}
}
++inputIdx;
}
}
}
/*
if (consumer->getOperator()->getNbProducedData(outId) > 0) {
// make sure consumer is also a producer
producers.insert(consumer);
const auto& newConsumers = getConsumers({consumer});
consumers.insert(newConsumers.cbegin(), newConsumers.cend());
consumers.erase(consumer);
if (isProducer) {
// make sure consumer is also a producer
producers.insert(consumer);
const auto& newConsumers = getConsumers({consumer});
consumers.insert(newConsumers.cbegin(), newConsumers.cend());
}
if (isStillConsumer) {
// If there is still data to consume, the consumer will be
// run AFTER the other remaining consumers
// (= non-greedy consumers)
stillConsumers.insert(consumer);
// 8) If there is no more consumers, swap with possible "still consumers"
// This ensures that the "non-greedy" consumer behavior
if (consumers.empty()) {
consumers.swap(stillConsumers);
stillConsumers.clear();
}
} while (!consumers.empty());
fmt::print("/!\\ Remaining consumers: possible dead-lock\n");
fmt::print("********************\n");
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
std::vector<Aidge::SequentialScheduler::StaticSchedulingElement>
Aidge::SequentialScheduler::generateEarlyLateScheduling() const {
std::vector<StaticSchedulingElement> scheduling;
size_t latest = 0;
for (size_t elt = 0; elt < mStaticSchedule.at(0).size(); ++elt) {
const auto node = mStaticSchedule.at(0)[elt];
const auto itNode = std::find_if(scheduling.rbegin(), scheduling.rend(), [node](const auto& v) { return (v.node == node); });
// Find early step: node can be run after the latest parent was run
// also, node must be run after latest node run!
size_t early = 0;
if (itNode != scheduling.rend()) {
early = (*itNode).early + 1;
}
for (const auto parent : node->getParents()) {
// Find parent node latest scheduled position
const auto it = std::find(mStaticSchedule.at(0).rend() - elt, mStaticSchedule.at(0).rend(), parent);
if (it != mStaticSchedule.at(0).rend()) {
const size_t step = std::distance(mStaticSchedule.at(0).begin(), it.base()) - 1;
early = std::max(early, scheduling[step].early + 1);
}
}
/*
// Update late step for parents
for (const auto parent : node->getParents()) {
const auto it = std::find(mStaticSchedule.at(0).rend() - elt, mStaticSchedule.at(0).rend(), parent);
if (it != mStaticSchedule.at(0).rend()) {
const size_t step = std::distance(mStaticSchedule.at(0).begin(), it.base()) - 1;
scheduling[step].late = std::min(scheduling[step].late, early - 1);
latest = std::max(latest, scheduling[step].late);
}
}
*/
latest = std::max(latest, early);
size_t late = static_cast<size_t>(-1);
scheduling.push_back(StaticSchedulingElement(node, early, late));
}
for (size_t elt = mStaticSchedule.at(0).size(); elt-- != 0; ) {
const auto node = mStaticSchedule.at(0)[elt];
const auto itNode = std::find_if(scheduling.begin() + elt + 1, scheduling.end(), [node](const auto& v) { return (v.node == node); });
size_t late = latest;
if (itNode != scheduling.end()) {
late = (*itNode).late - 1;
}
for (const auto child : node->getChildren()) {
// Find child node earliest scheduled position
const auto it = std::find(mStaticSchedule.at(0).begin() + elt + 1, mStaticSchedule.at(0).end(), child);
if (it != mStaticSchedule.at(0).end()) {
const size_t step = std::distance(mStaticSchedule.at(0).begin(), it);
late = std::min(late, scheduling[step].late - 1);
}
}
scheduling[elt].late = late;
}
return scheduling;
}
/**
 * @brief Reset the scheduler to its initial state.
 *
 * Resets the consumer-producer state of the operator of every node of the
 * graph view, then clears the static schedule, the current static schedule
 * step and the list of recorded scheduling elements.
 */
void Aidge::SequentialScheduler::resetScheduling() {
    const auto& graphNodes = mGraphView->getNodes();
    for (const auto& nodePtr : graphNodes) {
        nodePtr->getOperator()->resetConsummerProducer();
    }

    mScheduling.clear();
    mStaticScheduleStep = 0;
    mStaticSchedule.clear();
}
/**
 * @brief Build a memory mapping for the outputs of the scheduled nodes.
 *
 * This version is a simplified version without special handling of
 * concatenation.
 *
 * Nodes are visited in static-schedule order. For each output of each node,
 * a memory plane is either freshly allocated or, when @p wrapAroundBuffer is
 * set and a suitable parent plane was found, obtained by re-allocating over
 * the memory of that parent's output.
 *
 * @param incProducers If false, Producer nodes get no memory plane: only
 *        their dependencies are released.
 * @param wrapAroundBuffer If true, allow a node output to re-use the memory
 *        plane of one of its parents.
 * @return The MemoryManager holding the resulting memory planes.
 */
Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
    MemoryManager memManager;

    for (const auto& schedule : mStaticSchedule) {
        for (const auto& node : schedule) {
            if (!incProducers && node->type() == Producer_Op::Type) {
                memManager.releaseDependencies(node);
                continue;
            }

            const auto childs = node->getChildren();
            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());

            // One entry per output: the parent plane selected for re-use
            // (nullptr if none).
            std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;

            // Allocate a memory plane for each node's output
            for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
                const size_t requiredSize = op->getRequiredMemory(outputIdx, {});

                // By default, specifies a fully monolithic memory block
                size_t size = requiredSize;
                size_t stride = 0;
                size_t length = 1;
                size_t count = 1;

                if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) {
                    // If it is possible, assume a NCHW layout
                    size = op->getOutput(outputIdx)->dims().end()[-3];
                    stride = size;
                    length = op->getOutput(outputIdx)->dims().end()[-1];
                    count = op->getOutput(outputIdx)->dims().end()[-2];
                }

                // Check if wrap around buffer is possible for this node
                // (re-using previous node outputs memory for this node outputs).
                // => only if this node is the only child of its parent(s)
                size_t wrapAroundSize = 0;
                size_t wrapAroundExtra = 0;
                wrapAroundMemPlane.push_back(nullptr);

                // Select the best parent among all allocable nodes for
                // reallocation, which is the one with most memory (in order
                // to minimize the reallocation size).
                IOIndex_t inputIdx = 0;
                for (const auto& parent : node->dataInputs()) {
                    if (parent.first && parent.first->getChildren(parent.second).size() == 1
                        // there might be no existing plane if the parent was
                        // not yet scheduled (because it may be a recurrent connection)
                        && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
                        // memSpace should not be already released
                        && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
                    {
                        const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx));
                        const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];

                        if (isWrappable || !memManager.isWrapAround(
                                    memPlane.memSpace,
                                    memPlane.getFinalOffset()
                                        - memPlane.memSpace->offset,
                                    requiredSize))
                        {
                            if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx)
                                && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
                            {
                                wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx);
                                // Always recompute the extra size for the newly
                                // selected plane: keeping the value computed for
                                // a previous, smaller candidate would over-size
                                // the reallocation when this plane is big enough.
                                wrapAroundExtra = (requiredSize > wrapAroundSize)
                                    ? (requiredSize - wrapAroundSize)
                                    : 0;
                                wrapAroundMemPlane[outputIdx] = &memPlane;
                            }

                            // A plane needing no extra memory cannot be improved upon.
                            if (wrapAroundExtra == 0) {
                                break;
                            }
                        }
                    }
                    ++inputIdx;
                }

                // MemoryPlane to (re)use
                const MemoryManager::MemoryPlane& memPlane
                    = (wrapAroundBuffer && wrapAroundSize > 0)
                        ? (*wrapAroundMemPlane[outputIdx]) :
                            memManager.allocate(requiredSize, childs, stride, length, count);

                if (wrapAroundBuffer && wrapAroundSize > 0) {
                    memManager.reallocate(memPlane,
                        node, 0,
                        requiredSize, true, wrapAroundExtra, childs, stride, length, count);
                }
                else {
                    memManager.reallocate(memPlane.memSpace,
                        node, memPlane.offset,
                        requiredSize, false, 0, childs, stride, length, count);
                }
            }

            memManager.releaseDependencies(node);
            memManager.tick();
        }
    }

    return memManager;
}
Thibault Allenet
committed
/**
 * @brief Set the given tensors as inputs of the graph.
 *
 * The i-th tensor of @p data is set on the i-th entry (node, input index) of
 * the graph view's ordered inputs.
 *
 * @param data One tensor per graph input, in the order returned by
 *        `mGraphView->getOrderedInputs()`.
 *
 * If the number of tensors differs from the number of graph inputs, the error
 * is raised through AIDGE_THROW_OR_ABORT as a std::runtime_error.
 */
void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){
    // This version of connect inputs only connects tensor inputs in input data producers.
    const auto inputNodes = mGraphView->getOrderedInputs();

    // The number of tensors must match the number of graph inputs. Unlike
    // assert(), this check is kept in release (NDEBUG) builds, where a
    // mismatch would otherwise index inputNodes out of bounds below.
    if (data.size() != inputNodes.size()) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Scheduler connectInput error - Inconsistent number of graph inputs ({}) and inputs passed to the graph ({})",
            inputNodes.size(), data.size());
    }

    for (std::size_t i = 0; i < data.size(); ++i){
        // TODO : maybe shallow copy instead of deepcopy
        inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
    }
}
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive calls to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
size_t cpt = 0;
for (const auto& runnable : mStaticSchedule.at(mStaticScheduleStep)) {
if (verbose)
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
if (!fp) {
AIDGE_THROW_OR_ABORT(std::runtime_error,
"Could not create scheduling diagram log file: {}", fileName + ".mmd");
}
fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q µs\n\n");
Olivier BICHLER
committed
const std::map<std::shared_ptr<Node>, std::string> namePtrTable
= mGraphView->getRankedNodesName("{0} ({1}#{3})");
const auto globalStart = mScheduling[0].start;
for (const auto& element : mScheduling) {
auto name = namePtrTable.at(element.node);
// Mermaid does not allow : character in task title
std::replace(name.begin(), name.end(), ':', '_');
std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
}
}
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
/**
 * @brief Save a static scheduling as a Mermaid gantt diagram.
 *
 * Writes `<fileName>.mmd`, with one task line per scheduling element giving
 * its early and late steps.
 *
 * @param fileName Output file path, without the ".mmd" extension.
 * @param scheduling Elements (node, early, late) to write.
 *
 * If the file cannot be created, the error is raised through
 * AIDGE_THROW_OR_ABORT as a std::runtime_error.
 */
void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName, const std::vector<StaticSchedulingElement>& scheduling) const {
    const std::string filePath = fileName + ".mmd";
    std::unique_ptr<FILE, decltype(&std::fclose)> diagramFile(std::fopen(filePath.c_str(), "w"), &std::fclose);

    if (diagramFile == nullptr) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create scheduling diagram log file: {}", filePath);
    }

    FILE* const out = diagramFile.get();
    fmt::print(out, "gantt\ndateFormat x\naxisFormat %Q\n\n");

    if (!scheduling.empty()) {
        const std::map<std::shared_ptr<Node>, std::string> namePtrTable
            = mGraphView->getRankedNodesName("{0} ({1}#{3})");

        for (const auto& element : scheduling) {
            std::string taskName = namePtrTable.at(element.node);
            // Mermaid does not allow : character in task title
            std::replace(taskName.begin(), taskName.end(), ':', '_');

            fmt::print(out, "{} :{}, {}\n", taskName, element.early, element.late);
        }
    }

    fmt::print(out, "\n");
}
/**
 * @brief Collect the children of the given producers that belong to the
 *        scheduler's graph view.
 *
 * @param producers Nodes whose children are candidate consumers.
 * @return The set of children of @p producers for which
 *         `mGraphView->inView(child)` holds.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
        const std::set<std::shared_ptr<Node>>& producers) const {
    std::set<std::shared_ptr<Node>> consumers;

    for (const auto& producer : producers) {
        const auto& childs = producer->getChildren();
        for (const auto& child : childs) {
            // Do not schedule childs outside current graph!
            if (mGraphView->inView(child)) {
                consumers.insert(child);
            }
        }
    }

    return consumers;
}
/**
 * @brief Get the amount of data available on a given input of a node.
 *
 * Resolution order:
 *  1. If the input is connected to a parent node, the amount produced by the
 *     corresponding parent output.
 *  2. Otherwise, if this scheduler has an upper node (`mUpperNode`) and the
 *     input is one of the graph view's ordered inputs, the amount produced by
 *     whatever is connected to the upper node input at the same position.
 *  3. Otherwise, if a raw input tensor is set on the operator, the size of
 *     that tensor.
 *  4. Otherwise an error is raised through AIDGE_THROW_OR_ABORT.
 *
 * @param node Node to query.
 * @param inputIdx Index of the input of @p node.
 * @return The amount of available data for this input.
 */
Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
    const auto parent = node->inputs()[inputIdx];

    if (parent.first) {
        // Parent is connected, everything is fine!
        return parent.first->getOperator()->getNbProducedData(parent.second);
    }
    else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) {
        // We are inside an upper operator (for instance a MetaOperator)
        // We need to connect the "local" producer-consumer model to the upper
        // one, by mapping local node inputs to the upper node inputs.
        IOIndex_t nodeInputIdx = 0;
        for (const auto& input : mGraphView->getOrderedInputs()) {
            // An ordered input is a (node, input index) pair: both must match,
            // otherwise a node exposing several graph inputs would always be
            // mapped to the first of them.
            if (input.first == node && input.second == inputIdx) {
                // Current node input is a graph input
                const auto upperInput = upperNode->inputs()[nodeInputIdx];
                if (upperInput.first) {
                    return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
                }
            }
            // Keep the position in the ordered inputs in sync with the
            // upper node input index.
            ++nodeInputIdx;
        }
    }

    // Otherwise, two cases:
    if (node->getOperator()->getRawInput(inputIdx)) {
        // Input is not connected but a valid tensor exists
        // => This means data was fed manually to the input, without a Producer
        // In this case, we assume a single-use data (unlike a Producer, which
        // keep producing the data each time it is needed).
        fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
        return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
    }
    else {
        // Input is not connected, this is an error
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
    }

    // Fallback return after the error path above.
    return 0;
}
/**
 * @brief Determine whether a node is a "prior" and collect what it needs.
 *
 * Every connected input of @p node whose parent has not produced enough data
 * (consumed + required > produced) is examined:
 *  - parent outside the graph view: a default-constructed result is returned;
 *  - parent is a Producer: it is added to the required producers and @p node
 *    to the prior consumers;
 *  - parent is a Memorize: a default-constructed result is returned, in order
 *    to break cycles;
 *  - any other parent: the function recurses on it. If the parent is not a
 *    prior, a default-constructed result is returned, otherwise its required
 *    producers and prior consumers are merged into the result.
 *
 * If no input leads to an early return, the result is flagged `isPrior`, and
 * @p node itself becomes the prior consumer when none was collected.
 *
 * @param node Node to analyse.
 * @return The prior status of @p node with the associated producers and
 *         consumers. A default-constructed value is presumably "not a prior"
 *         (callers test `isPrior`) -- to be confirmed against the declaration
 *         of PriorProducersConsumers.
 */
Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
    const std::shared_ptr<Node>& node) const
{
    PriorProducersConsumers prior;

    IOIndex_t inputIdx = 0;
    for (const auto& parent : node->inputs()) {
        if (parent.first &&
            (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
                    parent.first->getOperator()->getNbProducedData(parent.second))
        {
            if (!mGraphView->inView(parent.first)) {
                // Do not schedule prior outside the current graph!
                return PriorProducersConsumers();
            }

            if (parent.first->type() == Producer_Op::Type) {
                prior.requiredProducers.insert(parent.first);
                prior.priorConsumers.insert(node);
            }
            else if (parent.first->type() == Memorize_Op::Type) {
                // Break cycles
                return PriorProducersConsumers();
            }
            else {
                const auto& parentPrior = getPriorProducersConsumers(parent.first);

                if (!parentPrior.isPrior) {
                    return PriorProducersConsumers();
                }
                else {
                    prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
                    prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
                }
            }
        }
        ++inputIdx;
    }

    prior.isPrior = true;
    if (prior.priorConsumers.empty()) {
        prior.priorConsumers.insert(node);
    }
    return prior;
}