Skip to content
Snippets Groups Projects
Scheduler.cpp 35.7 KiB
Newer Older
Cyril Moineau's avatar
Cyril Moineau committed
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/ThreadPool.hpp"
Cyril Moineau's avatar
Cyril Moineau committed

#include <chrono>
#include <memory>
#include <set>
#include <string>

Olivier BICHLER's avatar
Olivier BICHLER committed
#include <fmt/color.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
Maxence Naud's avatar
Maxence Naud committed
#include "aidge/operator/OperatorTensor.hpp"
Olivier BICHLER's avatar
Olivier BICHLER committed
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
Cyril Moineau's avatar
Cyril Moineau committed

void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
Cyril Moineau's avatar
Cyril Moineau committed
    putchar('[');
    int pos = static_cast<int>(barWidth * progress);
    for (int i = 0; i < barWidth; ++i) {
        if (i <= pos)
            putchar('#');
        else
            putchar(' ');
    }
Olivier BICHLER's avatar
Olivier BICHLER committed
    fmt::print("] {}% | {}\r", static_cast<int>(progress * 100), additionalInfo);
Cyril Moineau's avatar
Cyril Moineau committed
    fflush(stdout);
}

void Aidge::SequentialScheduler::generateScheduling() {
    auto schedule = generateBaseScheduling();
    generateEarlyLateScheduling(schedule);
    mStaticSchedule.push_back(schedule);
}

std::vector<std::shared_ptr<Aidge::SequentialScheduler::StaticSchedulingElement>> Aidge::SequentialScheduler::generateBaseScheduling() const {
    // TODO: For loop on the list of node to run
    // run sequencially every runnable consumers once
    // TODO: handle memory allocation in scheduler
    // TODO: optimize memory usage

    // 1) Setup initial consumers list:
    // It is the list of input nodes
    std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
    // Plus the list of nodes inside the graph connected to an inner producer
Cyril Moineau's avatar
Cyril Moineau committed
    std::set<std::shared_ptr<Node>> producers;
    for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
Olivier BICHLER's avatar
Olivier BICHLER committed
        if (nodePtr->type() == Producer_Op::Type) {
Cyril Moineau's avatar
Cyril Moineau committed
            producers.insert(nodePtr);
        }
    }
    const auto producersConsumers = getConsumers(producers);
    consumers.insert(producersConsumers.begin(), producersConsumers.end());
Cyril Moineau's avatar
Cyril Moineau committed

    const std::map<std::shared_ptr<Node>, std::string> namePtrTable
        = mGraphView->getRankedNodesName("{0} ({1}#{3})");
    // Still consumers are consumers that were run by can still consume data.
    // They must be run AFTER the remaining consumer to ensure a non-greedy
    // producers-consumers model!
    std::set<std::shared_ptr<Node>> stillConsumers;

    std::vector<std::shared_ptr<StaticSchedulingElement>> schedule;
Cyril Moineau's avatar
Cyril Moineau committed
    do {
        // 2) From the current consumers list, check if any prior consumer node
        // is needed. A prior will generally be required for any node consuming 
        // parameters (weights and bias) that is not an input node.
        // If for a given node, only parent producers (at any depth) are needed
        // to satisfy its required data, it becomes a prior.
        // If the prior node is a producer, it is added to the list of required
        // producers.
        // If the prior node is of another type, it replaces the initial consumer
        // in the new priorConsumers list. The initial consumer will become 
        // again a consumer later, by construction.
        Log::debug("List of consumers with their priors:");
Olivier BICHLER's avatar
Olivier BICHLER committed
        std::set<std::shared_ptr<Node>> requiredProducers;
        std::set<std::shared_ptr<Node>> priorConsumers;
Olivier BICHLER's avatar
Olivier BICHLER committed

        for (const auto& consumer : consumers) {
            Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));
Olivier BICHLER's avatar
Olivier BICHLER committed

            const auto& prior = getPriorProducersConsumers(consumer);
Olivier BICHLER's avatar
Olivier BICHLER committed

                std::vector<std::string> requiredProducersName;
                std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(),
                    std::back_inserter(requiredProducersName),
                    [&namePtrTable](auto val){ return namePtrTable.at(val); });
                Log::debug("\t\trequired producers: {}", requiredProducersName);

                std::vector<std::string> priorConsumersName;
                std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(),
                    std::back_inserter(priorConsumersName),
                    [&namePtrTable](auto val){ return namePtrTable.at(val); });
                Log::debug("\t\tprior consumers: {}", priorConsumersName);
                requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend());
                priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend());
            }
            else {
                priorConsumers.insert(consumer);
        // 3) Prior consumers replace the initial consumers list.
        // By construction, initial consumers will necessarily become consumers
        // again later.
        // 4) Make producers generate the required data.
        // Producers are special nodes that generate data on demand.
Olivier BICHLER's avatar
Olivier BICHLER committed
        for (const auto& requiredProducer : requiredProducers) {
            requiredProducer->getOperator()->updateConsummerProducer();
            schedule.push_back(std::make_shared<StaticSchedulingElement>(requiredProducer));
        // 5) Find runnable consumers.
        // A consumer is runnable if the required data is available for all of 
        // its inputs. At this point, not all consumers are necessarily
        // runnable because some may depend on the execution of others (when
        // there is multiple successive priors for example).
Cyril Moineau's avatar
Cyril Moineau committed
        std::set<std::shared_ptr<Node>> runnableConsumers;
        Log::debug("Updated list of consumers:");
Cyril Moineau's avatar
Cyril Moineau committed
        for (const auto& consumer : consumers) {
            Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));

            std::string crLog = "\t\tC/R:\t";
            for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                        consumer->getOperator()->getNbRequiredData(inId));
Cyril Moineau's avatar
Cyril Moineau committed
            }
            crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
                    consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
            Log::debug("{}", crLog);

            std::string pLog = "\t\tP:\t";
            for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
                pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
            }
            pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
            Log::debug("{}", pLog);

Cyril Moineau's avatar
Cyril Moineau committed
            bool isRunnable = true;
            for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
                if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
                    && */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) >
                            getNbAvailableData(consumer, inputIdx)) {
                    Log::debug("  not runnable: C{} + R{} > P{} for input #{}",
Olivier BICHLER's avatar
Olivier BICHLER committed
                        consumer->getOperator()->getNbConsumedData(inputIdx),
                        consumer->getOperator()->getNbRequiredData(inputIdx),
                        getNbAvailableData(consumer, inputIdx), inputIdx);
Cyril Moineau's avatar
Cyril Moineau committed
                    // not enough data to run
                    isRunnable = false;
                    break;
                }
            }

            if (isRunnable) {
                runnableConsumers.insert(consumer);
            }
        }

        // 5) If not consumer is runnable, it is a stop condition!
        if (runnableConsumers.empty()) {
            Log::debug("********************");
            // No consumer is runnable: some required data is missing for all of
            // them. There is two possibilities:
            // - At least one required data source is exhausted, which may be
            //   an expected stop condition.
            // - There is a deadlock between consumers, if some one is waiting
            //   for data from the other and reciprocally.
            break;
        }

        // 6) Push runnable consumers in the list of nodes to run and update the
        // consumer producer system.
        // At this point, simultaneously runnable consumers have no data 
        // dependency and could be run in parallel!
Cyril Moineau's avatar
Cyril Moineau committed
        for (const auto& runnable : runnableConsumers) {
            Log::debug("Runnable: {}", namePtrTable.at(runnable));
            runnable->getOperator()->updateConsummerProducer();
            schedule.push_back(std::make_shared<StaticSchedulingElement>(runnable));
        // 7) Update consumers list
        Log::debug("Updating producer and consumer lists...");
        for (const auto& consumer : runnableConsumers) {
            Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange)));

            std::string crLog = "\t\tC/R:\t";
            for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                        consumer->getOperator()->getNbRequiredData(inId));
            }
            crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
                    consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
            Log::debug("{}", crLog);

            std::string pLog = "\t\tP:\t";
            for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
                pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
Cyril Moineau's avatar
Cyril Moineau committed
            }
            pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
            Log::debug("{}", pLog);
Cyril Moineau's avatar
Cyril Moineau committed

            // 7.1) If the current consumer has still data to consume, it will
            // be put back in the consumers list once the remaining consumers
            // have been exhausted.
            bool isStillConsumer = false;
            for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
                if (consumer->getOperator()->getNbConsumedData(inputIdx) <
                            getNbAvailableData(consumer, inputIdx)) {
                    Log::debug("  still consumer: C{} < P{} for input #{}",
Olivier BICHLER's avatar
Olivier BICHLER committed
                        consumer->getOperator()->getNbConsumedData(inputIdx),
                        getNbAvailableData(consumer, inputIdx), inputIdx);
Cyril Moineau's avatar
Cyril Moineau committed
                    // there is still data to consume
                    isStillConsumer = true;
                    break;
                }
            }

            // 7.2) If the current consumer becomes a producer for other nodes,
            // its childs become consumers.
            bool isProducer = false;
            for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
                for (const auto& child : consumer->getChildren(outId)) {
                    if (child) {
                        IOIndex_t inputIdx = 0;
                        for (const auto& childParent : child->getParents()) {
                            if (childParent == consumer) {
                                if (consumer->getOperator()->getNbProducedData(outId) > child->getOperator()->getNbConsumedData(inputIdx)) {
                                    isProducer = true;
                                }
                            }
                            ++inputIdx;
                        }
                    }
                }
/*
Cyril Moineau's avatar
Cyril Moineau committed
                if (consumer->getOperator()->getNbProducedData(outId) > 0) {
                    Log::debug("  also producer");
Cyril Moineau's avatar
Cyril Moineau committed
                    // make sure consumer is also a producer
                    producers.insert(consumer);

                    const auto& newConsumers = getConsumers({consumer});
                    consumers.insert(newConsumers.cbegin(), newConsumers.cend());
Cyril Moineau's avatar
Cyril Moineau committed
                    break;
                }
            consumers.erase(consumer);

            if (isProducer) {
                Log::debug("  also producer");
                // make sure consumer is also a producer
                producers.insert(consumer);

                const auto& newConsumers = getConsumers({consumer});
                consumers.insert(newConsumers.cbegin(), newConsumers.cend());
            }

            if (isStillConsumer) {
                // If there is still data to consume, the consumer will be
                // run AFTER the other remaining consumers
                // (= non-greedy consumers)
                stillConsumers.insert(consumer);
        // 8) If there is no more consumers, swap with possible "still consumers"
        // This ensures that the "non-greedy" consumer behavior
        if (consumers.empty()) {
            consumers.swap(stillConsumers);
            stillConsumers.clear();
        }

        Log::debug("********************");
    if (!consumers.empty()) {
        Log::warn("Remaining consumers: possible dead-lock");
Olivier BICHLER's avatar
Olivier BICHLER committed
    }
    return schedule;
}
void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<std::shared_ptr<StaticSchedulingElement>>& schedule) const {
    size_t latest = 0;
    // Calculate early (logical) start
    for (size_t elt = 0; elt < schedule.size(); ++elt) {
        const auto node = schedule[elt]->node;
        const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(),
            [node](const auto& v) { return (v->node == node); });
        // Node can be run the earliest just after it was run the last time!
        size_t early = 0;
        if (itNode != schedule.rend()) {
            early = (*itNode)->early + 1;
            (*itNode)->earlierThan.push_back(schedule[elt]);
        // Node can be run the earliest just after its latest parent was run
        for (const auto& parent : node->getParents()) {
            // Find parent node latest scheduled position
            const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
                [parent](const auto& v) { return (v->node == parent); });
            if (it != schedule.rend()) {
                const size_t step = std::distance(schedule.begin(), it.base()) - 1;
                early = std::max(early, schedule[step]->early + 1);
                schedule[step]->earlierThan.push_back(schedule[elt]);
        latest = std::max(latest, early);
        schedule[elt]->early = early;
    // Calculate late (logical) start
    for (size_t elt = schedule.size(); elt-- != 0; ) {
        const auto node = schedule[elt]->node;
        const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(),
            [node](const auto& v) { return (v->node == node); });
        // Node can be run the latest just before it is run the next time!
        size_t late = latest;
        if (itNode != schedule.end()) {
            late = (*itNode)->late - 1;
        // Node can be run the latest just before its earliest child is run
        for (const auto& child : node->getChildren()) {
            // Find child node earliest scheduled position
            const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
                [child](const auto& v) { return (v->node == child); });
            if (it != schedule.end()) {
                const size_t step = std::distance(schedule.begin(), it);
                late = std::min(late, schedule[step]->late - 1);
        schedule[elt]->late = late;
void Aidge::SequentialScheduler::resetScheduling() {
    for (auto node : mGraphView->getNodes()) {
        node->getOperator()->resetConsummerProducer();
    }

    mStaticSchedule.clear();
    mStaticScheduleStep = 0;
    mScheduling.clear();
}

/**
 * This version is a simplified version without special handling of concatenation.
*/
Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const {
    MemoryManager memManager;

    for (size_t step = 0; step < mStaticSchedule.size(); ++step) {
        for (const auto& node : getStaticScheduling(step)) {
            if (!incProducers && node->type() == Producer_Op::Type) {
                memManager.releaseDependencies(node);
                continue;
            }
            
            const auto childs = node->getChildren();
            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
            const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());

            std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;

            // Allocate a memory plane for each node's output
            for (IOIndex_t outputIdx = 0; outputIdx < node->nbOutputs(); ++outputIdx) {
                const size_t requiredSize = op->getRequiredMemory(outputIdx, {});

                // By default, specifies a fully monolithic memory block
                size_t size = requiredSize;
                size_t stride = 0;
                size_t length = 1;
                size_t count = 1;

                if (op->getOutput(outputIdx) && op->getOutput(outputIdx)->dims().size() > 3) {
                    // If it is possible, assume a NCHW layout
                    size = op->getOutput(outputIdx)->dims().end()[-3];
                    stride = size;
                    length = op->getOutput(outputIdx)->dims().end()[-1];
                    count = op->getOutput(outputIdx)->dims().end()[-2];
                }
                
                // Check if wrap around buffer is possible for this node
                // (re-using previous node outputs memory for this node outputs).
                // => only if this node is the only child of its parent(s)
                size_t wrapAroundSize = 0;
                size_t wrapAroundExtra = 0;
                wrapAroundMemPlane.push_back(nullptr);

                // Select the best parent among all allocable nodes for 
                // reallocation, which is the one with most memory (in order
                // to minimize the reallocation size).
                IOIndex_t inputIdx = 0;
                for (const auto& parent : node->dataInputs()) {
                    if (parent.first && parent.first->getChildren(parent.second).size() == 1
                        // there might be no existing plane if the parent was
                        // not yet scheduled (because it may be a recurrent connection)
                        && memManager.getNbPlanes(parent.first) >= parent.first->nbOutputs()
                        // memSpace should not be already released
                        && memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second].memSpace->released == -1)
                    {
                        const bool isWrappable = (op->getNbRequiredProtected(inputIdx) < op->getNbRequiredData(inputIdx));
                        const MemoryManager::MemoryPlane& memPlane = memManager.getPlanes(parent.first).end()[-parent.first->nbOutputs()+parent.second];

                        if (isWrappable || !memManager.isWrapAround(
                                    memPlane.memSpace,
                                    memPlane.getFinalOffset()
                                        - memPlane.memSpace->offset,
                                    requiredSize))
                        {
                            if (memPlane.getSize() > wrapAroundSize + op->getNbRequiredProtected(inputIdx)
                                && std::find(wrapAroundMemPlane.begin(), wrapAroundMemPlane.end(), &memPlane) == wrapAroundMemPlane.end())
                            {
                                wrapAroundSize = memPlane.getSize() - op->getNbRequiredProtected(inputIdx);
                                if (requiredSize > wrapAroundSize) {
                                    wrapAroundExtra = requiredSize - wrapAroundSize;
                                }
                                wrapAroundMemPlane[outputIdx] = &memPlane;
                            }

                            if (wrapAroundExtra == 0) {
                                break;
                            }
                        }
                    }
                    ++inputIdx;
                }

                // MemoryPlane to (re)use
                const MemoryManager::MemoryPlane& memPlane
                    = (wrapAroundBuffer && wrapAroundSize > 0)
                        ? (*wrapAroundMemPlane[outputIdx]) :
                            memManager.allocate(requiredSize, childs, stride, length, count);

                if (wrapAroundBuffer && wrapAroundSize > 0) {
                    memManager.reallocate(memPlane,
                        node, 0,
                        requiredSize, true, wrapAroundExtra, childs, stride, length, count);
                }
                else {
                    memManager.reallocate(memPlane.memSpace,
                        node, memPlane.offset,
                        requiredSize, false, 0, childs, stride, length, count);
                }
            }

            memManager.releaseDependencies(node);
            memManager.tick();
        }
    }

    return memManager;
}

void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){
    // This version of connect inputs only connects tensor inputs in input data producers.
    auto inputNodes = mGraphView->getOrderedInputs();

    // Assert that the number of input data producers corresponds to the number of data input
    assert(data.size() == inputNodes.size()  && "Scheduler connectInput error - Inconsistent number of graph inputs and inputs passed to the graph");
    
    for (std::size_t i = 0; i < data.size(); ++i){
        // TODO : maybe shallow copy instead of deepcopy
        inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
    }
}


void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
    
    // Collect all data input of the graph (that are producers)
    if (!data.empty()){
        connectInputs(data);
    }

    // Forward dims (if allowed)
    if (forwardDims) {mGraphView->forwardDims(); }

    // Generate scheduling *only if empty*
    // If scheduling was already generated (in one or several steps, i.e. one or
    // several successive call to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
Olivier BICHLER's avatar
Olivier BICHLER committed
    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
    for (const auto& runnable : getStaticScheduling(mStaticScheduleStep)) {
Olivier BICHLER's avatar
Olivier BICHLER committed
            fmt::print("run: {}\n", namePtrTable.at(runnable));
            drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
Olivier BICHLER's avatar
Olivier BICHLER committed
                            (std::string("running ") + namePtrTable.at(runnable)));
        const auto tStart = std::chrono::high_resolution_clock::now();
        runnable->forward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
Cyril Moineau's avatar
Cyril Moineau committed
    if (!verbose) drawProgressBar(1.0, 50, "                                   ");
Olivier BICHLER's avatar
Olivier BICHLER committed
    fmt::print("\n");
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
Cyril Moineau's avatar
Cyril Moineau committed
}

void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
    auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
Olivier BICHLER's avatar
Olivier BICHLER committed

    if (!fp) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create scheduling diagram log file: {}", fileName + ".mmd");
    }

    fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q µs\n\n");
Cyril Moineau's avatar
Cyril Moineau committed

    if (!mScheduling.empty()) {
        const std::map<std::shared_ptr<Node>, std::string> namePtrTable
            = mGraphView->getRankedNodesName("{0} ({1}#{3})");
Cyril Moineau's avatar
Cyril Moineau committed
        const auto globalStart = mScheduling[0].start;

        for (const auto& element : mScheduling) {
            auto name = namePtrTable.at(element.node);
            // Mermaid does not allow : character in task title
            std::replace(name.begin(), name.end(), ':', '_');

Olivier BICHLER's avatar
Olivier BICHLER committed
            fmt::print(fp.get(), "{} :{}, {}\n",
Cyril Moineau's avatar
Cyril Moineau committed
                         std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count(),
                         std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
        }
    }

Olivier BICHLER's avatar
Olivier BICHLER committed
    fmt::print(fp.get(), "\n");
void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName) const {
    auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);

    if (!fp) {
        AIDGE_THROW_OR_ABORT(std::runtime_error,
            "Could not create scheduling diagram log file: {}", fileName + ".mmd");
    }

    fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n");

    if (!mStaticSchedule.empty()) {
        const std::map<std::shared_ptr<Node>, std::string> namePtrTable
            = mGraphView->getRankedNodesName("{0} ({1}#{3})");

        for (const auto& schedule : mStaticSchedule) {
            for (const auto& element : schedule) {
                auto name = namePtrTable.at(element->node);
                // Mermaid does not allow : character in task title
                std::replace(name.begin(), name.end(), ':', '_');
                fmt::print(fp.get(), "{} :{}, {}\n",
                            name, element->early, element->late);
std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const {
    const auto& staticSchedule = mStaticSchedule.at(step);
    std::vector<std::shared_ptr<Node>> schedule;
    std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
Cyril Moineau's avatar
Cyril Moineau committed
std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
        const std::set<std::shared_ptr<Node>>& producers) const {
    std::set<std::shared_ptr<Node>> consumers;

    for (const auto& producer : producers) {
        const auto& childs = producer->getChildren();
        for (const auto& child : childs) {
            // Do not schedule childs outside current graph!
            if (mGraphView->inView(child)) {
                consumers.insert(child);
            }
        }
Cyril Moineau's avatar
Cyril Moineau committed
    }

    return consumers;
Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const {
    const auto parent = node->inputs()[inputIdx];

    if (parent.first) {
        // Parent is connected, everything if fine!
        return parent.first->getOperator()->getNbProducedData(parent.second);
    }
    else if (std::shared_ptr<Node> upperNode = mUpperNode.lock()) {
        // We are inside an upper operator (for instance a MetaOperator)
        // We need to connect the "local" producer-consumer model to the upper
        // one, by mapping local node inputs to the upper node inputs.
        IOIndex_t nodeInputIdx = 0;
        for (const auto& input : mGraphView->getOrderedInputs()) {
            if (input.first == node) {
                // Current node is an input
                const auto upperInput = upperNode->inputs()[nodeInputIdx];
                if (upperInput.first) {
                    return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
                } 
            }
            ++nodeInputIdx;
        }
    }

    // Otherwise, two cases:
    if (node->getOperator()->getRawInput(inputIdx)) {
        // Input is not connected but a valid tensor exists
        // => This means data was fed manually to the input, without a Producer
        // In this case, we assume a single-use data (unlike a Producer, which
        // keep producing the data each time it is needed).
        fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
        return std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size();
    }
    else {
        // Input is not connected, this is an error
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type());
    }

    return 0;
}

Aidge::SequentialScheduler::PriorProducersConsumers Aidge::SequentialScheduler::getPriorProducersConsumers(
    const std::shared_ptr<Node>& node) const
{
    PriorProducersConsumers prior;

    IOIndex_t inputIdx = 0;
    for (const auto& parent : node->inputs()) {
        if (parent.first &&
            (node->getOperator()->getNbConsumedData(inputIdx) + node->getOperator()->getNbRequiredData(inputIdx)) >
                    parent.first->getOperator()->getNbProducedData(parent.second))
        {
            if (!mGraphView->inView(parent.first)) {
                // Do not schedule prior outside the current graph!
                return PriorProducersConsumers();
            }

            if (parent.first->type() == Producer_Op::Type) {
                prior.requiredProducers.insert(parent.first);
                prior.priorConsumers.insert(node);
            }
            else if (parent.first->type() == Memorize_Op::Type) {
                // Break cycles
                return PriorProducersConsumers();
            }
            else {
                const auto& parentPrior = getPriorProducersConsumers(parent.first);

                if (!parentPrior.isPrior) {
                    return PriorProducersConsumers();
                }
                else {
                    prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
                    prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
                }
            }
        }
        ++inputIdx;
    }

    prior.isPrior = true;
    if (prior.priorConsumers.empty()) {
        prior.priorConsumers.insert(node);
    }
    return prior;
}

void Aidge::ParallelScheduler::forward(bool forwardDims, bool /*verbose*/, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
    
    // Collect all data input of the graph (that are producers)
    if (!data.empty()){
        connectInputs(data);
    }

    // Forward dims (if allowed)
    if (forwardDims) {mGraphView->forwardDims(); }

    // Generate scheduling *only if empty*
    // If scheduling was already generated (in one or several steps, i.e. one or
    // several successive call to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    // Sort static scheduling, the order will be the prefered threads scheduling
    // order for non critical nodes
    std::deque<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
    std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
        [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); });

    // The thread pool has N threads running to process nodes.
    // Thread pooling avoid the overhead of threads creation and deletion for
    // each node execution.
    ThreadPool pool;
    pool.start();

    size_t latest = 0;
    std::mutex schedulingMutex;
    std::vector<int> required(staticSchedule.back()->late + 1, 0);
    std::vector<std::atomic<int>> finished(staticSchedule.back()->late + 1);
    std::fill(finished.begin(), finished.end(), 0);

    while (!staticSchedule.empty()) {
        Log::debug("Step {}", latest);

        // Run all nodes that must be run at latest
        for (size_t i = 0; i < staticSchedule.size(); ) {
            auto runnable = staticSchedule[i];

            if (runnable->late == latest) {
                // Critical path
                pool.queueJob([node = runnable->node, &finished = finished[latest], &schedulingMutex, this]() {
                    const auto tStart = std::chrono::high_resolution_clock::now();
                    node->forward();
                    const auto tEnd = std::chrono::high_resolution_clock::now();
                    ++finished;
                    {
                        std::unique_lock<std::mutex> lock(schedulingMutex);
                        mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
                    }
                });
                staticSchedule.erase(staticSchedule.begin() + i);
                ++required[latest];

                Log::debug("  run critical {}", namePtrTable.at(runnable->node));

                for (auto elt : runnable->earlierThan) {
                    if (elt->early < latest + 1) {
                        Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
                        elt->early = latest + 1;
                    }
                }
            }
            else if (runnable->early > latest + 1) {
                // There cannot be more node that must be run at latest + 1
                // latest + 1 and not latest because early may have been updated
                // for some elements to latest + 1 (above).
                break;
            }
            else {
                ++i;
            }
        }

        // If some threads are still available, run next early nodes
        while (!pool.busy() && !staticSchedule.empty()) {
            auto runnable = staticSchedule.front();
            if (runnable->early <= latest) {
                pool.queueJob([node = runnable->node, &finished = finished[runnable->late], &schedulingMutex, this]() {
                    const auto tStart = std::chrono::high_resolution_clock::now();
                    node->forward();
                    const auto tEnd = std::chrono::high_resolution_clock::now();
                    ++finished;
                    {
                        std::unique_lock<std::mutex> lock(schedulingMutex);
                        mScheduling.emplace_back(SchedulingElement(node, tStart, tEnd));
                    }
                });
                staticSchedule.pop_front();
                ++required[runnable->late];

                Log::debug("  run {}", namePtrTable.at(runnable->node));

                for (auto elt : runnable->earlierThan) {
                    if (elt->early < latest + 1) {
                        Log::debug("    {}: {} -> {}", namePtrTable.at(elt->node), elt->early, latest + 1);
                        elt->early = latest + 1;
                    }
                }
            }
            else {
                break;
            }
        }

        // Wait for all nodes that must finish at latest to be finished
        while (finished[latest] < required[latest]) {
            std::this_thread::yield();
        }

        ++latest;
    }

    pool.stop();

    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}