From 4e6c97893f2a3ad9d6a2bc6ab6c3c47e45d0f8a9 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 9 Oct 2024 17:14:52 +0200 Subject: [PATCH] Removed nothingtodohere file --- src/scheduler/Scheduler copy.old | 767 ------------------------------- 1 file changed, 767 deletions(-) delete mode 100644 src/scheduler/Scheduler copy.old diff --git a/src/scheduler/Scheduler copy.old b/src/scheduler/Scheduler copy.old deleted file mode 100644 index 0c3b20515..000000000 --- a/src/scheduler/Scheduler copy.old +++ /dev/null @@ -1,767 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/scheduler/Scheduler.hpp" - -#include <algorithm> // std::find, std::find_if, std::max, std::min, std::replace, std::transform -#include <cassert> -#include <chrono> -#include <cstddef> // std::size_t -#include <cstdio> // std::fclose, std::fopen -#include <iterator> // std::back_inserter, std::distance -#include <map> -#include <memory> -#include <set> -#include <string> -#include <vector> - -#include <fmt/core.h> -#include <fmt/color.h> -#include <fmt/ranges.h> - -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Node.hpp" -#include "aidge/operator/Memorize.hpp" -#include "aidge/operator/MetaOperator.hpp" -#include "aidge/operator/OperatorTensor.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/utils/Log.hpp" -#include "aidge/utils/Types.h" - - -Aidge::Scheduler::~Scheduler() noexcept = default; -Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers() = default; -Aidge::Scheduler::PriorProducersConsumers::PriorProducersConsumers(const PriorProducersConsumers&) = default; -Aidge::Scheduler::PriorProducersConsumers::~PriorProducersConsumers() noexcept = default; - -void Aidge::Scheduler::generateScheduling() { - auto schedule = generateBaseScheduling(); - generateEarlyLateScheduling(schedule); - mStaticSchedule.push_back(schedule); -} - -std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::Scheduler::generateBaseScheduling() const { - - // 0) setup useful variables - // map associating each node with string "name (type#rank)" - const std::map<std::shared_ptr<Node>, std::string> namePtrTable - = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - - // consumers that were run by but can still consume data. - // They must be run AFTER the remaining consumer to ensure a non-greedy - // producers-consumers model! - std::set<std::shared_ptr<Node>> stillConsumers; - - std::vector<std::shared_ptr<StaticSchedulingElement>> schedule; - - - // 1) Initialize consumers list: - // 1.1) List of the GraphView's input nodes - std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes(); - - // 1.2) List of nodes inside the GraphView connected to an inner Producer - std::set<std::shared_ptr<Node>> producers; - for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { - if (nodePtr->type() == Producer_Op::Type) { - for (const auto& child : nodePtr->getChildren()) { - // Do not schedule childs outside current graph! - if (mGraphView->inView(child)) { - consumers.insert(child); - } - } - } - } - - do { - // 2) From the current consumers list, check if any prior consumer node - // is needed. A prior will generally be required for any node consuming - // parameters (weights and bias) that is not an input node. - // If for a given node, only parent producers (at any depth) are needed - // to satisfy its required data, it becomes a prior. - // If the prior node is a producer, it is added to the list of required - // producers. - // If the prior node is of another type, it replaces the initial consumer - // in the new priorConsumers list. The initial consumer will become - // again a consumer later, by construction. - Log::debug("List of consumers with their priors:"); - std::set<std::shared_ptr<Node>> requiredProducers; // Priors of type Producer - std::set<std::shared_ptr<Node>> priorConsumers; // Priors of other type - mPriorCache.clear(); - - for (const auto& consumer : consumers) { - Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); - - const auto& prior = getPriorProducersConsumers(consumer); - - if (prior.isPrior) { - std::vector<std::string> requiredProducersName; - std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(), - std::back_inserter(requiredProducersName), - [&namePtrTable](auto val){ return namePtrTable.at(val); }); - Log::debug("\t\trequired producers: {}", requiredProducersName); - - std::vector<std::string> priorConsumersName; - std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(), - std::back_inserter(priorConsumersName), - [&namePtrTable](auto val){ return namePtrTable.at(val); }); - Log::debug("\t\tprior consumers: {}", priorConsumersName); - - requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend()); - priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend()); - } - else { - priorConsumers.insert(consumer); - } - } - - // 3) Prior consumers replace the initial consumers list. - // By construction, initial consumers will necessarily become consumers - // again later. - consumers.swap(priorConsumers); - - // 4) Make producers generate the required data. - // Producers are special nodes that generate data on demand. - for (const auto& requiredProducer : requiredProducers) { - requiredProducer->getOperator()->updateConsummerProducer(); - schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer)); - } - - // 5) Find runnable consumers. - // A consumer is runnable if the required data is available for all of - // its inputs. At this point, not all consumers are necessarily - // runnable because some may depend on the execution of others (when - // there is multiple successive priors for example). - std::set<std::shared_ptr<Node>> runnableConsumers; - Log::debug("Updated list of consumers:"); - for (const auto& consumer : consumers) { - summarizeConsumerState(consumer, namePtrTable.at(consumer)); // debug print - - bool isRunnable = true; - for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { - AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx); - - if ((consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > - getNbAvailableData(consumer, inputIdx)) { - Log::debug(" not runnable: C{} + R{} > P{} for input #{}", - consumer->getOperator()->getNbConsumedData(inputIdx), - consumer->getOperator()->getNbRequiredData(inputIdx), - getNbAvailableData(consumer, inputIdx), inputIdx); - - // not enough data to run - isRunnable = false; - break; - } - } - - if (isRunnable) { - runnableConsumers.insert(consumer); - } - } - - // 5) If not consumer is runnable, it is a stop condition! - if (runnableConsumers.empty()) { - Log::debug("********************"); - // No consumer is runnable: some required data is missing for all of - // them. There is two possibilities: - // - At least one required data source is exhausted, which may be - // an expected stop condition. - // - There is a deadlock between consumers, if some one is waiting - // for data from the other and reciprocally. - break; - } - - // 6) Push runnable consumers in the list of nodes to run and update the - // consumer producer system. - // At this point, simultaneously runnable consumers have no data - // dependency and could be run in parallel! - for (const auto& runnable : runnableConsumers) { - Log::debug("Runnable: {}", namePtrTable.at(runnable)); - runnable->getOperator()->updateConsummerProducer(); - schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable)); - } - - // 7) Update consumers list - Log::debug("Updating producer and consumer lists..."); - for (const auto& consumer : runnableConsumers) { - summarizeConsumerState(consumer, namePtrTable.at(consumer)); // debug print - // 7.1) If the current consumer has still data to consume, it will - // be put back in the consumers list once the remaining consumers - // have been exhausted. - bool isStillConsumer = false; - // Only look for data inputs. If no data is available on data input, - // by definition, no parameter can be consumed on parameter inputs. - for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { - if (consumer->inputCategory(inputIdx) == InputCategory::Data) { - AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx); - - if (consumer->getOperator()->getNbConsumedData(inputIdx) < - getNbAvailableData(consumer, inputIdx)) { - Log::debug(" still consumer: C{} < P{} for input #{}", - consumer->getOperator()->getNbConsumedData(inputIdx), - getNbAvailableData(consumer, inputIdx), inputIdx); - - // there is still data to consume - isStillConsumer = true; - break; - } - } - } - - // 7.2) If the current consumer becomes a producer for other nodes, - // its childs become consumers. - bool isProducer = false; - for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) { - for (const auto& child : consumer->getChildren(outId)) { - if (child && mGraphView->inView(child)) { - IOIndex_t inputIdx = 0; - for (const auto& childParent : child->getParents()) { - if (childParent == consumer) { - AIDGE_LOG_CONTEXT("Consumer node {} input #{} / Producer node {} output #{}", - namePtrTable.at(child), inputIdx, namePtrTable.at(consumer), outId); - - if (child->getOperator()->getNbConsumedData(inputIdx) < consumer->getOperator()->getNbProducedData(outId)) { - isProducer = true; - break; - } - } - ++inputIdx; - } - - if (isProducer) { - break; - } - } - } -/* - if (consumer->getOperator()->getNbProducedData(outId) > 0) { - Log::debug(" also producer"); - // make sure consumer is also a producer - producers.insert(consumer); - - const auto& newConsumers = getConsumers({consumer}); - consumers.insert(newConsumers.cbegin(), newConsumers.cend()); - break; - } -*/ - } - - consumers.erase(consumer); - - if (isProducer) { - Log::debug(" also producer"); - // make sure consumer is also a producer - producers.insert(consumer); - - const auto& newConsumers = getConsumers({consumer}); - consumers.insert(newConsumers.cbegin(), newConsumers.cend()); - } - - if (isStillConsumer) { - // If there is still data to consume, the consumer will be - // run AFTER the other remaining consumers - // (= non-greedy consumers) - stillConsumers.insert(consumer); - } - } - - // 8) If there is no more consumers, swap with possible "still consumers" - // This ensures that the "non-greedy" consumer behavior - if (consumers.empty()) { - consumers.swap(stillConsumers); - stillConsumers.clear(); - } - - Log::debug("********************"); - } while (!consumers.empty()); - - mPriorCache.clear(); - - if (!consumers.empty()) { - std::vector<std::string> consumersName; - std::transform(consumers.begin(), consumers.end(), - std::back_inserter(consumersName), - [&namePtrTable](auto val){ return namePtrTable.at(val); }); - - Log::warn("Remaining consumers: {}. Possible dead-lock.", consumersName); - } - - return schedule; -} - - -void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>& consumer, const std::string& nodeName) const { - Log::debug("\t- consumer: {}", fmt::styled(nodeName, fg(fmt::color::orange))); - std::string crLog = "\t\tC/R:\t"; - for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), - consumer->getOperator()->getNbRequiredData(inId)); - } - crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), - consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); - Log::debug("{}", crLog); - - std::string pLog = "\t\tP:\t"; - for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); - } - pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); - Log::debug("{}", pLog); -} - - -void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const { - std::size_t latest = 0; - // Calculate early (logical) start - for (std::size_t elt = 0; elt < schedule.size(); ++elt) { - const auto node = schedule[elt]->node; - const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(), - [node](const auto& v) { return (v->node == node); }); - - // Node can be run the earliest just after its childs were run the last time! - std::size_t early = 0; - if (itNode != schedule.rend()) { - for (const auto& child : node->getChildren()) { - // Find child node next scheduled position - const auto it = std::find_if(schedule.rend() - elt, itNode, - [child](const auto& v) { return (v->node == child); }); - AIDGE_INTERNAL_ASSERT(it != schedule.rend()); - - const std::size_t step = std::distance(schedule.begin(), it.base()) - 1; - early = std::max(early, schedule[step]->early + 1); - schedule[step]->earlierThan.push_back(schedule[elt]); - } - } - - // Node can be run the earliest just after its latest parent was run - for (const auto& parent : node->getParents()) { - // Find parent node latest scheduled position - const auto it = std::find_if(schedule.rend() - elt, schedule.rend(), - [parent](const auto& v) { return (v->node == parent); }); - if (it != schedule.rend()) { - const std::size_t step = std::distance(schedule.begin(), it.base()) - 1; - early = std::max(early, schedule[step]->early + 1); - schedule[step]->earlierThan.push_back(schedule[elt]); - } - } - - latest = std::max(latest, early); - schedule[elt]->early = early; - } - - // Calculate late (logical) start - for (std::size_t elt = schedule.size(); elt-- != 0; ) { - const auto node = schedule[elt]->node; - const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(), - [node](const auto& v) { return (v->node == node); }); - - // Node can be run the latest just before its parents are run the next time! - std::size_t late = latest; - if (itNode != schedule.end()) { - for (const auto& parent : node->getParents()) { - // Find child node next scheduled position - const auto it = std::find_if(schedule.begin() + elt + 1, itNode, - [parent](const auto& v) { return (v->node == parent); }); - AIDGE_INTERNAL_ASSERT(it != schedule.end()); - - const std::size_t step = std::distance(schedule.begin(), it); - late = std::min(late, schedule[step]->late - 1); - schedule[step]->laterThan.push_back(schedule[elt]); - } - } - - // Node can be run the latest just before its earliest child is run - for (const auto& child : node->getChildren()) { - // Find child node earliest scheduled position - const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), - [child](const auto& v) { return (v->node == child); }); - if (it != schedule.end()) { - const std::size_t step = std::distance(schedule.begin(), it); - late = std::min(late, schedule[step]->late - 1); - schedule[step]->laterThan.push_back(schedule[elt]); - } - } - - schedule[elt]->late = late; - } -} - -void Aidge::Scheduler::resetScheduling() { - for (auto node : mGraphView->getNodes()) { - node->getOperator()->resetConsummerProducer(); - } - - mStaticSchedule.clear(); - mStaticScheduleStep = 0; - mScheduling.clear(); -} - -/** - * This version is a simplified version without special handling of concatenation. -*/ -Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { - MemoryManager memManager; - - for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { - for (const auto& node : getStaticScheduling(step)) { - if (!incProducers && node->type() == Producer_Op::Type) { - memManager.releaseDependencies(node); - continue; - } - - const auto childs = node->getChildren(); - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, - "Operator must be of Tensor type for node {} (of type {}).", - node->name(), node->type()); - const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); - - std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; - - // Allocate a memory plane for each node's output - for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) { - const auto requiredSize = op->getRequiredMemory(outputIdx, {}); - AIDGE_ASSERT(requiredSize.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", - node->name(), node->type()); - - // By default, specifies a fully monolithic memory block - std::size_t size = requiredSize.data; - std::size_t stride = 0; - std::size_t length = 1; - std::size_t count = 1; - - if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) { - // If it is possible, assume a NCHW layout - size = op->getOutput(outputIdx)->dims().end()[-3]; - stride = size; - length = op->getOutput(outputIdx)->dims().end()[-1]; - count = op->getOutput(outputIdx)->dims().end()[-2]; - } - - // Check if wrap around buffer is possible for this node - // (re-using previous node outputs memory for this node outputs). - // => only if this node is the only child of its parent(s) - std::size_t wrapAroundSize = 0; - std::size_t wrapAroundExtra = 0; - wrapAroundMemPlane.push_back(nullptr); - - // Select the best parent among all allocable nodes for - // reallocation, which is the one with most memory (in order - // to minimize the reallocation size). - IOIndex_t inputIdx = 0; - for (const auto& parent : node->dataInputs()) { - if (parent.first && parent.first->getChildren(parent.second).size() == 1 - // there might be no existing plane if the parent was - // not yet scheduled (because it may be a recurrent connection) - && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs() - // memSpace should not be already released - && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1) - { - const auto requiredData = op->getNbRequiredData(inputIdx); - const auto requiredProtected = op->getNbRequiredProtected(inputIdx); - AIDGE_ASSERT(requiredData.type == Elts_t::Data && requiredProtected.type == Elts_t::Data, - "Cannot generate memory with token-based producer-consumer model for node {} (of type {}).", - node->name(), node->type()); - - const bool isWrappable = (requiredProtected.data < requiredData.data); - const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second]; - - if (isWrappable || !memManager.isWrapAround( - memPlane.memSpace, - memPlane.getFinalOffset() - - memPlane.memSpace->offset, - requiredSize.data)) - { - if (memPlane.getSize() > wrapAroundSize + requiredProtected.data - && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end()) - { - wrapAroundSize = memPlane.getSize() - requiredProtected.data; - if (requiredSize.data > wrapAroundSize) { - wrapAroundExtra = requiredSize.data - wrapAroundSize; - } - wrapAroundMemPlane[outputIdx] = &memPlane; - } - - if (wrapAroundExtra == 0) { - break; - } - } - } - ++inputIdx; - } - - // MemoryPlane to (re)use - const MemoryManager::MemoryPlane& memPlane - = (wrapAroundBuffer && wrapAroundSize > 0) - ? (*wrapAroundMemPlane[outputIdx]) : - memManager.allocate(size, childs, stride, length, count); - - if (wrapAroundBuffer && wrapAroundSize > 0) { - memManager.reallocate(memPlane, - node, 0, - size, true, wrapAroundExtra, childs, stride, length, count); - } - else { - memManager.reallocate(memPlane.memSpace, - node, memPlane.offset, - size, false, 0, childs, stride, length, count); - } - } - - memManager.releaseDependencies(node); - memManager.tick(); - } - } - - return memManager; -} - -void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data){ - // This version of connect inputs only connects tensor inputs in input data producers. - auto inputNodes = mGraphView->getOrderedInputs(); - - std::size_t i = 0; - for (auto& input : inputNodes) { - if (i < data.size() && data[i]) { - // TODO : maybe shallow copy instead of deepcopy - input.first->getOperator()->setInput(input.second, data[i]); - } - else { - const auto& currentTensorPtr = - std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second); - const bool optional = (input.first->inputCategory(input.second) == InputCategory::OptionalData - || input.first->inputCategory(input.second) == InputCategory::OptionalParam); - - if (currentTensorPtr) { - Log::debug("connectInputs(): existing tensor dims are {} for graph input#{} for input#{} of node {} (of type {})", - i, input.second, input.first->name(), input.first->type(), currentTensorPtr->dims()); - } - else if (!optional) { - Log::warn("connectInputs(): did not specify tensor for mandatory graph input#{} for input#{} of node {} (of type {})", - i, input.second, input.first->name(), input.first->type()); - } - } - ++i; - } -} - -void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const { - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); - - if (!fp) { - AIDGE_THROW_OR_ABORT(std::runtime_error, - "Could not create scheduling diagram log file: {}", fileName + ".mmd"); - } - - fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q µs\n\n"); - - if (!mScheduling.empty()) { - const std::map<std::shared_ptr<Node>, std::string> namePtrTable - = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - const auto globalStart = mScheduling[0].start; - - for (const auto& element : mScheduling) { - auto name = namePtrTable.at(element.node); - // Mermaid does not allow : character in task title - std::replace(name.begin(), name.end(), ':', '_'); - - fmt::print(fp.get(), "{} :{}, {}\n", - name, - std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(), - std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count()); - } - } - - fmt::print(fp.get(), "\n"); -} - -void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { - auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); - - if (!fp) { - AIDGE_THROW_OR_ABORT(std::runtime_error, - "Could not create scheduling diagram log file: {}", fileName + ".mmd"); - } - - fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n"); - - if (!mStaticSchedule.empty()) { - const std::map<std::shared_ptr<Node>, std::string> namePtrTable - = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - - for (const auto& schedule : mStaticSchedule) { - for (const auto& element : schedule) { - auto name = namePtrTable.at(element->node); - // Mermaid does not allow : character in task title - std::replace(name.begin(), name.end(), ':', '_'); - - fmt::print(fp.get(), "{} :{}, {}\n", - name, element->early, element->late); - } - } - } - - fmt::print(fp.get(), "\n"); -} - -std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const { - AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?"); - AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step); - - const auto& staticSchedule = mStaticSchedule.at(step); - std::vector<std::shared_ptr<Node>> schedule; - std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; }); - return schedule; -} - -std::set<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getConsumers( - const std::set<std::shared_ptr<Node>>& producers) const { - std::set<std::shared_ptr<Node>> consumers; - - for (const auto& producer : producers) { - const auto& childs = producer->getChildren(); - for (const auto& child : childs) { - // Do not schedule childs outside current graph! - if (mGraphView->inView(child)) { - consumers.insert(child); - } - } - } - - return consumers; -} - -Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const { - const auto parent = node->inputs()[inputIdx]; - Log::info("getNbAvailableData(): input#{} of node {} (of type {})", inputIdx, node->name(), node->type()); - - if (parent.first) { - // Parent is connected, everything if fine! - Log::info("getNbAvailableData(): output#{} of node {} (of type {})", parent.second, parent.first->name(), parent.first->type()); - return parent.first->getOperator()->getNbProducedData(parent.second); - } - else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) { - if (upperNode != node) { - AIDGE_LOG_CONTEXT("Mapping input#{} for node {} ({}) to upper node {} ({}) inputs", - inputIdx, node->name(), node->type(), upperNode->name(), upperNode->type()); - - // We are inside an upper operator (for instance a MetaOperator) - // We need to connect the "local" producer-consumer model to the upper - // one, by mapping local node inputs to the upper node inputs. - IOIndex_t upperInputIdx = 0; - for (const auto& input : mGraphView->getOrderedInputs()) { - if (input.first == node && input.second == inputIdx) { - // Current node is an input - return getNbAvailableData(upperNode, upperInputIdx); - } - ++upperInputIdx; - } - - // This should not happen! - AIDGE_INTERNAL_ASSERT(true); - } - } - - // Otherwise, it means that the input is not connected. Two cases: - // - There is no data, it is assumed to be an optional input - // - A valid tensor exists: - if (node->getOperator()->getRawInput(inputIdx)) { - // => This means data was fed manually to the input, without a Producer - // In this case, we assume a single-use data (unlike a Producer, which - // keep producing the data each time it is needed). - Log::warn("No producer node attached to input#{} for node {} ({})", inputIdx, node->name(), node->type()); - return Elts_t::DataElts(std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size()); - } - - // This should normally not happen because checks are already made earlier in forwardDims() - AIDGE_ASSERT((node->getOperator()->inputCategory(inputIdx) != InputCategory::Data - && node->getOperator()->inputCategory(inputIdx) != InputCategory::Param), - "No data provided to mandatory input#{} for node {} ({})", - inputIdx, node->name(), node->type()); - - return Elts_t::NoneElts(); -} - -Aidge::Scheduler::PriorProducersConsumers Aidge::Scheduler::getPriorProducersConsumers( - const std::shared_ptr<Node>& node) const -{ - const auto priorCache = mPriorCache.find(node); - if (priorCache != mPriorCache.end()) { - return priorCache->second; - } - - PriorProducersConsumers prior; - - IOIndex_t inputIdx = 0; - for (const auto& parent : node->inputs()) { - NodePtr passthroughParent; - IOIndex_t passthroughInputIdx; - - if (parent.first) { - passthroughParent = parent.first; - passthroughInputIdx = parent.second; - } - else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) { - IOIndex_t upperInputIdx = 0; - for (const auto& input : mGraphView->getOrderedInputs()) { - if (input.first == node && input.second == inputIdx) { - passthroughParent = upperNode; - passthroughInputIdx = upperInputIdx; - break; - } - ++upperInputIdx; - } - } - - if (passthroughParent) { - AIDGE_LOG_CONTEXT("Producer node {} (of type {}) output #{}", - passthroughParent->name(), passthroughParent->type(), passthroughInputIdx); - - if ((node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) > - passthroughParent->getOperator()->getNbProducedData(passthroughInputIdx)) - { - // the node needs more data than the current parent has provided yet - if (!mGraphView->inView(passthroughParent)) { - // Do not schedule prior outside the current graph! - // return PriorProducersConsumers(); // not scheduled - prior.priorConsumers.insert(node); - } - - else if (passthroughParent->type() == Producer_Op::Type) { - prior.requiredProducers.insert(passthroughParent); - prior.priorConsumers.insert(node); - } - else if (passthroughParent->type() == Memorize_Op::Type) { - // Break cycles - return PriorProducersConsumers(); // not scheduled - } - else { - const auto& parentPrior = getPriorProducersConsumers(passthroughParent); - - if (!parentPrior.isPrior) { - return PriorProducersConsumers(); // not scheduled - } - else { - prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend()); - prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend()); - } - } - } - } - ++inputIdx; - } - - prior.isPrior = true; - if (prior.priorConsumers.empty()) { - prior.priorConsumers.insert(node); - } - mPriorCache.insert(std::make_pair(node, prior)); - return prior; -} -- GitLab