/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/ThreadPool.hpp"

#include <chrono>
#include <memory>
#include <set>
#include <string>

#include <fmt/ranges.h>
#include <fmt/color.h>

#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"

void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
    putchar('[');
    int pos = static_cast<int>(barWidth * progress);
    for (int i = 0; i < barWidth; ++i) {
        if (i <= pos)
            putchar('#');
        else
            putchar(' ');
    }
    fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo);
    fflush(stdout);
}

void Aidge::SequentialScheduler::generateScheduling() {
    auto schedule = generateBaseScheduling();
    generateEarlyLateScheduling(schedule);
    mStaticSchedule.push_back(schedule);
}

std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>> Aidge::SequentialScheduler::generateBaseScheduling() const {
    // TODO: loop over the list of nodes to run and run every runnable consumer sequentially, once
    // TODO: handle memory allocation in scheduler
    // TODO: optimize memory usage

    // 1) Setup initial consumers list:
    // It is the list of input nodes...
    std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
    // ...plus the list of nodes inside the graph connected to an inner producer
    std::set<std::shared_ptr<Node>> producers;
    for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
        if (nodePtr->type() == Producer_Op::Type) {
            producers.insert(nodePtr);
        }
    }
    const auto producersConsumers = getConsumers(producers);
    consumers.insert(producersConsumers.begin(), producersConsumers.end());

    const std::map<std::shared_ptr<Node>, std::string> namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    // "Still consumers" are consumers that were already run but can still
    // consume data. They must be run AFTER the remaining consumers to ensure
    // a non-greedy producers-consumers model!
    std::set<std::shared_ptr<Node>> stillConsumers;

    std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;

    do {
        // 2) From the current consumers list, check if any prior consumer node
        // is needed. A prior will generally be required for any node consuming
        // parameters (weights and bias) that is not an input node.
        // If, for a given node, only parent producers (at any depth) are needed
        // to satisfy its required data, it becomes a prior.
        // If the prior node is a producer, it is added to the list of required
        // producers.
        // If the prior node is of another type, it replaces the initial consumer
        // in the new priorConsumers list. The initial consumer will become
        // a consumer again later, by construction.
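        // Concrete illustration (hypothetical graph, not taken from this file):
        // a Conv node fed by weight and bias Producer parents is not an input
        // node, but its Producers alone can satisfy its parameter inputs, so
        // the Conv becomes a prior and its Producers are scheduled before it.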
Log::debug("List of consumers with their priors:"); std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> priorConsumers; for (const auto& consumer : consumers) { Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); const auto& prior = getPriorProducersConsumers(consumer); if (prior.isPrior) { std::vector<std::string> requiredProducersName; std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(), std::back_inserter(requiredProducersName), [&namePtrTable](auto val){ return namePtrTable.at(val); }); Log::debug("\t\trequired producers: {}", requiredProducersName); std::vector<std::string> priorConsumersName; std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(), std::back_inserter(priorConsumersName), [&namePtrTable](auto val){ return namePtrTable.at(val); }); Log::debug("\t\tprior consumers: {}", priorConsumersName); requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend()); priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend()); } else { priorConsumers.insert(consumer); } } // 3) Prior consumers replace the initial consumers list. // By construction, initial consumers will necessarily become consumers // again later. consumers.swap(priorConsumers); // 4) Make producers generate the required data. // Producers are special nodes that generate data on demand. for (const auto& requiredProducer : requiredProducers) { requiredProducer->getOperator()->updateConsummerProducer(); schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer)); } // 5) Find runnable consumers. // A consumer is runnable if the required data is available for all of // its inputs. At this point, not all consumers are necessarily // runnable because some may depend on the execution of others (when // there is multiple successive priors for example). 
        std::set<std::shared_ptr<Node>> runnableConsumers;
        Log::debug("Updated list of consumers:");
        for (const auto& consumer : consumers) {
            Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));

            std::string crLog = "\t\tC/R:\t";
            for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                    consumer->getOperator()->getNbRequiredData(inId));
            }
            crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
                consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
            Log::debug("{}", crLog);

            std::string pLog = "\t\tP:\t";
            for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
                pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
            }
            pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
            Log::debug("{}", pLog);

            bool isRunnable = true;
            for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
                if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0 && */
                    (consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
                            getNbAvailableData(consumer, inputIdx))
                {
                    Log::debug("  not runnable: C{} + R{} > P{} for input #{}",
                        consumer->getOperator()->getNbConsumedData(inputIdx),
                        consumer->getOperator()->getNbRequiredData(inputIdx),
                        getNbAvailableData(consumer, inputIdx), inputIdx);

                    // not enough data to run
                    isRunnable = false;
                    break;
                }
            }

            if (isRunnable) {
                runnableConsumers.insert(consumer);
            }
        }

        // If no consumer is runnable, it is a stop condition!
        if (runnableConsumers.empty()) {
            Log::debug("********************");
            // No consumer is runnable: some required data is missing for all of
            // them. There are two possibilities:
            // - At least one required data source is exhausted, which may be
            //   an expected stop condition.
            // - There is a deadlock between consumers, when each one is waiting
            //   for data from another.
            break;
        }

        // 6) Push runnable consumers in the list of nodes to run and update the
        // consumer-producer system.
        // At this point, simultaneously runnable consumers have no data
        // dependency and could be run in parallel! For instance, the branches
        // of a fork in the graph become runnable together once the fork node
        // has produced its outputs.
        for (const auto& runnable : runnableConsumers) {
            Log::debug("Runnable: {}", namePtrTable.at(runnable));
            runnable->getOperator()->updateConsummerProducer();
            schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
        }

        // 7) Update the consumers list
        Log::debug("Updating producer and consumer lists...");
        for (const auto& consumer : runnableConsumers) {
            Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));

            std::string crLog = "\t\tC/R:\t";
            for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                    consumer->getOperator()->getNbRequiredData(inId));
            }
            crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
                consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
            Log::debug("{}", crLog);

            std::string pLog = "\t\tP:\t";
            for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
                pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
            }
            pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
            Log::debug("{}", pLog);

            // 7.1) If the current consumer still has data to consume, it will
            // be put back in the consumers list once the remaining consumers
            // have been exhausted.
            bool isStillConsumer = false;
            for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
                if (consumer->getOperator()->getNbConsumedData(inputIdx) <
                            getNbAvailableData(consumer, inputIdx))
                {
                    Log::debug("  still consumer: C{} < P{} for input #{}",
                        consumer->getOperator()->getNbConsumedData(inputIdx),
                        getNbAvailableData(consumer, inputIdx), inputIdx);

                    // there is still data to consume
                    isStillConsumer = true;
                    break;
                }
            }

            // 7.2) If the current consumer becomes a producer for other nodes,
            // its children become consumers.
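            // Hypothetical illustration: once a Conv node has produced its
            // first output, a following ReLU child sees newly available data
            // (produced > consumed) and is added to the consumers list below.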
            bool isProducer = false;
            for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
                for (const auto& child : consumer->getChildren(outId)) {
                    if (child) {
                        IOIndex_t inputIdx = 0;
                        for (const auto& childParent : child->getParents()) {
                            if (childParent == consumer) {
                                if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
                                    isProducer = true;
                                }
                            }
                            ++inputIdx;
                        }
                    }
                }
/*
                if (consumer->getOperator()->getNbProducedData(outId) > 0) {
                    Log::debug("  also producer");
                    // make sure consumer is also a producer
                    producers.insert(consumer);

                    const auto& newConsumers = getConsumers({consumer});
                    consumers.insert(newConsumers.cbegin(), newConsumers.cend());
                    break;
                }
*/
            }

            consumers.erase(consumer);

            if (isProducer) {
                Log::debug("  also producer");
                // make sure consumer is also a producer
                producers.insert(consumer);

                const auto& newConsumers = getConsumers({consumer});
                consumers.insert(newConsumers.cbegin(), newConsumers.cend());
            }

            if (isStillConsumer) {
                // If there is still data to consume, the consumer will be
                // run AFTER the other remaining consumers
                // (= non-greedy consumers)
                stillConsumers.insert(consumer);
            }
        }

        // 8) If there are no more consumers, swap with the possible "still
        // consumers". This ensures the "non-greedy" consumer behavior.
        if (consumers.empty()) {
            consumers.swap(stillConsumers);
            stillConsumers.clear();
        }

        Log::debug("********************");
    } while (!consumers.empty());

    if (!consumers.empty()) {
        Log::warn("Remaining consumers: possible dead-lock");
    }

    return schedule;
}

void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
    size_t latest = 0;
    // Calculate early (logical) start
    for (size_t elt = 0; elt < schedule.size(); ++elt) {
        const auto node = schedule[elt]->node;
        const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
            [node](const auto& v) { return (v->node == node); });

        // Node can be run the earliest just after its children were run for the last time!
        size_t early = 0;
        if (itNode != schedule.rend()) {
            for (const auto& child : node->getChildren()) {
                // Find the child node's next scheduled position
                const auto it = std::find_if(schedule.rend() - elt, itNode,
                    [child](const auto& v) { return (v->node == child); });
                AIDGE_INTERNAL_ASSERT(it != schedule.rend());

                const size_t step = std::distance(schedule.begin(), it.base()) - 1;
                early = std::max(early, schedule[step]->early + 1);
                schedule[step]->earlierThan.push_back(schedule[elt]);
            }
        }

        // Node can be run the earliest just after its latest parent was run
        for (const auto& parent : node->getParents()) {
            // Find the parent node's latest scheduled position
            const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
                [parent](const auto& v) { return (v->node == parent); });
            if (it != schedule.rend()) {
                const size_t step = std::distance(schedule.begin(), it.base()) - 1;
                early = std::max(early, schedule[step]->early + 1);
                schedule[step]->earlierThan.push_back(schedule[elt]);
            }
        }

        latest = std::max(latest, early);
        schedule[elt]->early = early;
    }

    // Calculate late (logical) start
    for (size_t elt = schedule.size(); elt-- != 0; ) {
        const auto node = schedule[elt]->node;
        const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
            [node](const auto& v) { return (v->node == node); });

        // Node can be run the latest just before its parents are run the next time!
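        // Worked example of the early/late values (hypothetical graph): for a
        // diamond A -> {B, C} -> D scheduled as [A, B, C, D], the early starts
        // computed above are {0, 1, 1, 2} and the late starts computed below
        // are also {0, 1, 1, 2}: B and C share logical step 1 and may run
        // concurrently, which is what the ParallelScheduler exploits.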
        size_t late = latest;
        if (itNode != schedule.end()) {
            for (const auto& parent : node->getParents()) {
                // Find the parent node's next scheduled position
                const auto it = std::find_if(schedule.begin() + elt + 1, itNode,
                    [parent](const auto& v) { return (v->node == parent); });
                AIDGE_INTERNAL_ASSERT(it != schedule.end());

                const size_t step = std::distance(schedule.begin(), it);
                late = std::min(late, schedule[step]->late - 1);
                schedule[step]->laterThan.push_back(schedule[elt]);
            }
        }

        // Node can be run the latest just before its earliest child is run
        for (const auto& child : node->getChildren()) {
            // Find the child node's earliest scheduled position
            const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
                [child](const auto& v) { return (v->node == child); });
            if (it != schedule.end()) {
                const size_t step = std::distance(schedule.begin(), it);
                late = std::min(late, schedule[step]->late - 1);
                schedule[step]->laterThan.push_back(schedule[elt]);
            }
        }

        schedule[elt]->late = late;
    }
}

void Aidge::SequentialScheduler::resetScheduling() {
    for (auto node : mGraphView->getNodes()) {
        node->getOperator()->resetConsummerProducer();
    }
    mStaticSchedule.clear();
    mStaticScheduleStep = 0;
    mScheduling.clear();
}

/**
 * This is a simplified version without special handling of concatenation.
 */
Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
    MemoryManager memManager;

    for (size_t step = 0; step < mStaticSchedule.size(); ++step) {
        for (const auto& node : getStaticScheduling(step)) {
            if (!incProducers && node->type() == Producer_Op::Type) {
                memManager.releaseDependencies(node);
                continue;
            }

            const auto childs = node->getChildren();
            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());

            std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;

            // Allocate a memory plane for each of the node's outputs
            for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
                const size_t requiredSize = op->getRequiredMemory(outputIdx, {});

                // By default, specify a fully monolithic memory block
                size_t size = requiredSize;
                size_t stride = 0;
                size_t length = 1;
                size_t count = 1;

                if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) {
                    // If possible, assume a NCHW layout
                    size = op->getOutput(outputIdx)->dims().end()[-3];
                    stride = size;
                    length = op->getOutput(outputIdx)->dims().end()[-1];
                    count = op->getOutput(outputIdx)->dims().end()[-2];
                }

                // Check if a wrap-around buffer is possible for this node
                // (re-using a previous node's output memory for this node's outputs)
                // => only if this node is the only child of its parent(s)
                size_t wrapAroundSize = 0;
                size_t wrapAroundExtra = 0;
                wrapAroundMemPlane.push_back(nullptr);

                // Select the best parent among all allocable nodes for
                // reallocation, which is the one with the most memory (in order
                // to minimize the reallocation size).
                IOIndex_t inputIdx = 0;
                for (const auto& parent : node->dataInputs()) {
                    if (parent.first && parent.first->getChildren(parent.second).size() == 1
                        // there might be no existing plane if the parent was
                        // not yet scheduled (because it may be a recurrent connection)
                        && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
                        // memSpace should not be already released
                        && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
                    {
                        const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx));
                        const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];

                        if (isWrappable || !memManager.isWrapAround(
                                    memPlane.memSpace,
                                    memPlane.getFinalOffset() - memPlane.memSpace->offset,
                                    requiredSize))
                        {
                            if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx)
                                && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
                            {
                                wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx);
                                if (requiredSize > wrapAroundSize) {
                                    wrapAroundExtra = requiredSize - wrapAroundSize;
                                }
                                wrapAroundMemPlane[outputIdx] = &memPlane;
                            }

                            if (wrapAroundExtra == 0) {
                                break;
                            }
                        }
                    }
                    ++inputIdx;
                }

                // MemoryPlane to (re)use
                const MemoryManager::MemoryPlane& memPlane = (wrapAroundBuffer && wrapAroundSize > 0)
                    ? (*wrapAroundMemPlane[outputIdx])
                    : memManager.allocate(requiredSize, childs, stride, length, count);

                if (wrapAroundBuffer && wrapAroundSize > 0) {
                    memManager.reallocate(memPlane,
                        node, 0,
                        requiredSize, true, wrapAroundExtra, childs, stride, length, count);
                }
                else {
                    memManager.reallocate(memPlane.memSpace,
                        node, memPlane.offset,
                        requiredSize, false, 0, childs, stride, length, count);
                }
            }

            memManager.releaseDependencies(node);
            memManager.tick();
        }
    }

    return memManager;
}

void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data) {
    // This version of connectInputs only connects tensor inputs to input data producers.
    auto inputNodes = mGraphView->getOrderedInputs();

    // Assert that the number of input data producers corresponds to the number of data inputs
    assert(data.size() == inputNodes.size() && "Scheduler connectInput error - Inconsistent number of graph inputs and inputs passed to the graph");

    for (std::size_t i = 0; i < data.size(); ++i) {
        // TODO: maybe shallow copy instead of deep copy
        inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
    }
}

void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
    // Collect all data inputs of the graph (that are producers)
    if (!data.empty()) {
        connectInputs(data);
    }

    // Forward dims (if allowed)
    if (forwardDims) { mGraphView->forwardDims(); }

    // Generate scheduling *only if empty*
    // If the scheduling was already generated (in one or several steps, i.e. one or
    // several successive calls to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    const auto staticSchedule = getStaticScheduling(mStaticScheduleStep);
    size_t cpt = 0;
    for (const auto& runnable : staticSchedule) {
        if (verbose)
            fmt::print("run: {}\n", namePtrTable.at(runnable));
        else
            // Progress is the fraction of nodes already run in the current
            // static schedule step
            drawProgressBar(static_cast<float>(cpt) / static_cast<float>(staticSchedule.size()), 50,
                (std::string("running ") + namePtrTable.at(runnable)));

        const auto tStart = std::chrono::high_resolution_clock::now();
        runnable->forward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
        cpt++;
    }
    if (!verbose) drawProgressBar(1.0, 50, " ");
    fmt::print("\n");

    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}

void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
    auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);

    if (!fp) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create scheduling diagram log file: {}", fileName + ".mmd");
    }

    fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q µs\n\n");

    if (!mScheduling.empty()) {
        const std::map<std::shared_ptr<Node>, std::string> namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
        const auto globalStart = mScheduling[0].start;

        for (const auto& element : mScheduling) {
            auto name = namePtrTable.at(element.node);
            // Mermaid does not allow the ':' character in task titles
            std::replace(name.begin(), name.end(), ':', '_');

            fmt::print(fp.get(), "{} :{}, {}\n",
                     name,
                     std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
                     std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
        }
    }

    fmt::print(fp.get(), "\n");
}

void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName) const {
    auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);

    if (!fp) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create scheduling diagram log file: {}", fileName + ".mmd");
    }

    fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n");

    if (!mStaticSchedule.empty()) {
        const std::map<std::shared_ptr<Node>, std::string> namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

        for (const auto& schedule : mStaticSchedule) {
            for (const auto& element : schedule) {
                auto name = namePtrTable.at(element->node);
                // Mermaid does not allow the ':' character in task titles
                std::replace(name.begin(), name.end(), ':', '_');

                fmt::print(fp.get(), "{} :{}, {}\n",
                         name, element->early, element->late);
            }
        }
    }

    fmt::print(fp.get(), "\n");
}

std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const {
    const auto& staticSchedule = mStaticSchedule.at(step);
    std::vector<std::shared_ptr<Node>> schedule;
    std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule),
        [](const auto& v) { return v->node; });
    return schedule;
}

std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
    const std::set<std::shared_ptr<Node>>& producers) const
{
    std::set<std::shared_ptr<Node>> consumers;

    for (const auto& producer : producers) {
        const auto& childs = producer->getChildren();
        for (const auto& child : childs) {
            // Do not schedule children outside the current graph!
            if (mGraphView->inView(child)) {
                consumers.insert(child);
            }
        }
    }

    return consumers;
}

Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
    const auto parent = node->inputs()[inputIdx];

    if (parent.first) {
        // Parent is connected, everything is fine!
        return parent.first->getOperator()->getNbProducedData(parent.second);
    }
    else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) {
        // We are inside an upper operator (for instance a MetaOperator)
        // We need to connect the "local" producer-consumer model to the upper
        // one, by mapping local node inputs to the upper node inputs.
        IOIndex_t nodeInputIdx = 0;
        for (const auto& input : mGraphView->getOrderedInputs()) {
            if (input.first == node) {
                // Current node is an input
                const auto upperInput = upperNode->inputs()[nodeInputIdx];
                if (upperInput.first) {
                    return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
                }
            }
            ++nodeInputIdx;
        }
    }

    // Otherwise, two cases:
    if (node->getOperator()->getRawInput(inputIdx)) {
        // Input is not connected but a valid tensor exists
        // => This means data was fed manually to the input, without a Producer.
        // In this case, we assume single-use data (unlike a Producer, which
        // keeps producing the data each time it is needed).
        fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
        return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
    }
    else {
        // Input is not connected, this is an error
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
    }

    return 0;
}

Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
    const std::shared_ptr<Node>& node) const
{
    PriorProducersConsumers prior;

    IOIndex_t inputIdx = 0;
    for (const auto& parent : node->inputs()) {
        if (parent.first &&
            (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
                    parent.first->getOperator()->getNbProducedData(parent.second))
        {
            if (!mGraphView->inView(parent.first)) {
                // Do not schedule a prior outside the current graph!
                return PriorProducersConsumers();
            }

            if (parent.first->type() == Producer_Op::Type) {
                prior.requiredProducers.insert(parent.first);
                prior.priorConsumers.insert(node);
            }
            else if (parent.first->type() == Memorize_Op::Type) {
                // Break cycles
                return PriorProducersConsumers();
            }
            else {
                const auto& parentPrior = getPriorProducersConsumers(parent.first);

                if (!parentPrior.isPrior) {
                    return PriorProducersConsumers();
                }
                else {
                    prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
                    prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
                }
            }
        }
        ++inputIdx;
    }

    prior.isPrior = true;
    if (prior.priorConsumers.empty()) {
        prior.priorConsumers.insert(node);
    }
    return prior;
}

void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
    // Collect all data inputs of the graph (that are producers)
    if (!data.empty()) {
        connectInputs(data);
    }

    // Forward dims (if allowed)
    if (forwardDims) { mGraphView->forwardDims(); }

    // Generate scheduling *only if empty*
    // If the scheduling was already generated (in one or several steps, i.e. one or
    // several successive calls to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    // Sort the static scheduling: this order will be the preferred thread
    // scheduling order for non-critical nodes
    std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
    std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
        [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });

    // The thread pool has N threads running to process the nodes.
    // Thread pooling avoids the overhead of thread creation and deletion for
    // each node execution.
    ThreadPool pool;

    size_t latest = 0;
    std::mutex schedulingMutex;
    std::map<std::shared_ptr<StaticSchedulingElement>, std::atomic<bool>> finished;

    while (!staticSchedule.empty()) {
        Log::debug("Step {}", latest);

        std::vector<std::shared_ptr<StaticSchedulingElement>> mustFinish;

        // Run all nodes that must be run at this step: latest (critical nodes)
        for (size_t i = 0; i < staticSchedule.size(); ) {
            auto runnable = staticSchedule[i];

            if (runnable->late == latest) {
                // Wait for potential preceding non-critical nodes to finish
                while (true) {
                    bool ready = true;
                    for (auto elt : runnable->laterThan) {
                        ready = ready && finished.at(elt);
                    }
                    if (!ready) {
                        std::this_thread::yield();
                    }
                    else {
                        break;
                    }
                }

                // Add the critical node to the thread pool queue, to be run ASAP
                finished[runnable] = false;
                pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
                    const auto tStart = std::chrono::high_resolution_clock::now();
                    node->forward();
                    const auto tEnd = std::chrono::high_resolution_clock::now();
                    finished = true;
                    {
                        std::unique_lock<std::mutex> lock(schedulingMutex);
                        mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
                    }
                });
                staticSchedule.erase(staticSchedule.begin() + i);
                mustFinish.push_back(runnable);

                Log::debug("  run critical {}", namePtrTable.at(runnable->node));

                // Ensure the following nodes cannot start earlier than the next step
                for (auto elt : runnable->earlierThan) {
                    if (elt->early < latest + 1) {
                        Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
                        elt->early = latest + 1;
                        AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
                    }
                }
            }
            else if (runnable->early > latest + 1) {
                // There cannot be any more nodes that must be run at this step.
                // The bound is latest + 1 and not latest, because early may have
                // been updated to latest + 1 for some elements (above).
                break;
            }
            else {
                ++i;
            }
        }

        // If some threads are still available, run the next early nodes.
        // These nodes are non-critical, meaning their execution can still be
        // deferred to the next step at least.
        for (size_t i = 0; i < staticSchedule.size(); ) {
            auto runnable = staticSchedule[i];
            if (!pool.busy() && runnable->early <= latest) {
                // Check that potential preceding non-critical nodes are finished
                bool ready = true;
                for (auto elt : runnable->laterThan) {
                    ready = ready && finished.at(elt);
                }
                if (ready) {
                    // All preceding nodes have finished, this node can be run.
                    // Add the node to the thread pool queue, to be run ASAP
                    finished[runnable] = false;
                    pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
                        const auto tStart = std::chrono::high_resolution_clock::now();
                        node->forward();
                        const auto tEnd = std::chrono::high_resolution_clock::now();
                        finished = true;
                        {
                            std::unique_lock<std::mutex> lock(schedulingMutex);
                            mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
                        }
                    });
                    staticSchedule.erase(staticSchedule.begin() + i);

                    Log::debug("  run {}", namePtrTable.at(runnable->node));

                    // Ensure the following nodes cannot start earlier than the next step
                    for (auto elt : runnable->earlierThan) {
                        if (elt->early < latest + 1) {
                            Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
                            elt->early = latest + 1;
                            AIDGE_INTERNAL_ASSERT(elt->early <= elt->late);
                        }
                    }
                }
                else {
                    // The node cannot be run yet, because preceding nodes are
                    // still running; move to the next one in the schedule
                    ++i;
                }
            }
            else {
                // The thread pool is already full, or no more nodes can be run
                // at this step (latest)
                break;
            }
        }

        // Wait for all nodes that must finish at latest to be finished.
        // By scheduling construction, no other node can be started before all
        // nodes at the latest step are finished.
        while (true) {
            bool ready = true;
            for (auto elt : mustFinish) {
                ready = ready && finished.at(elt);
            }
            if (!ready) {
                std::this_thread::yield();
            }
            else {
                break;
            }
        }

        ++latest;
    }

    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}
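/*
 * Minimal usage sketch (hypothetical user code, not part of this file; it
 * assumes the schedulers are constructed from a GraphView and that forward()
 * defaults to forwarding dims with no extra input data, as in the headers):
 *
 *     auto graph = std::make_shared<Aidge::GraphView>();
 *     // ... populate the graph, set backends and data types ...
 *
 *     Aidge::SequentialScheduler scheduler(graph);
 *     scheduler.generateScheduling();               // build the static schedule (early/late)
 *     scheduler.forward();                          // run one scheduling step
 *     scheduler.saveSchedulingDiagram("schedule");  // writes schedule.mmd (Mermaid Gantt)
 */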