Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Scheduler.cpp 10.33 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/scheduler/Scheduler.hpp"

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <set>
#include <string>

#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"

/**
 * @brief Build the text of a one-line progress bar.
 *
 * The result has the form "[###   ] 50% | <additionalInfo>\r". The trailing
 * carriage return lets the next bar overwrite this one on a terminal.
 * @param progress Completion ratio, clamped to [0, 1] (NaN is treated as 0).
 * @param barWidth Number of cells between the brackets (<= 0 draws no cell).
 * @param additionalInfo Free text appended after the percentage.
 * @return The formatted bar.
 */
static std::string formatProgressBar(double progress, int barWidth, const std::string& additionalInfo) {
    // Written as !(x > 0) so NaN also falls into the lower clamp: NaN or
    // infinite values must never reach the float-to-int conversions below.
    if (!(progress > 0.0)) {
        progress = 0.0;
    } else if (progress > 1.0) {
        progress = 1.0;
    }
    const int filled = static_cast<int>(barWidth * progress);

    std::string bar("[");
    for (int i = 0; i < barWidth; ++i) {
        // Strictly less-than: 0% draws no '#', 100% fills every cell.
        bar += (i < filled) ? '#' : ' ';
    }
    bar += "] ";
    bar += std::to_string(static_cast<int>(progress * 100));
    bar += "% | ";
    bar += additionalInfo;
    bar += '\r';
    return bar;
}

/**
 * @brief Draw a progress bar on stdout and flush it immediately.
 *
 * The line ends with '\r' (no newline) so successive calls redraw in place.
 * @param progress Completion ratio, clamped to [0, 1] (NaN is treated as 0).
 * @param barWidth Number of cells between the brackets.
 * @param additionalInfo Free text appended after the percentage.
 */
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
    std::fputs(formatProgressBar(progress, barWidth, additionalInfo).c_str(), stdout);
    std::fflush(stdout);
}

void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
    // setup initial producers list
    mComputationNumber = 0;
    std::set<std::shared_ptr<Node>> producers;
    for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
        if (nodePtr->type() == "Producer") {
            producers.insert(nodePtr);
        } else {
            ++mComputationNumber;
        }
    }
    // add Data Input
    // FIXME : should be changed when the real system for providing
    // data is implemented
    for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) {
        for (const auto& parentPtr : nodePtr->getParents()) {
            if ((mGraphView->getNodes()).find(parentPtr) == (mGraphView->getNodes()).end()) {
                // Node not found in the graph, it's an outside producer
                producers.insert(parentPtr);
            }
        }
    }

    // setup consumer list
    // std::set<std::shared_ptr<Node>> consumers = getConsumers(producers);

    /* It may not be necessary to initialize producer */
    std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
    do {
        // find runnable consumers
        std::set<std::shared_ptr<Node>> runnableConsumers;
        if (verbose) printf("List of layers receiving data:\n");
        for (const auto& consumer : consumers) {
            if (verbose) {
                printf("\t- consumer: "
                       "\x1b[1;37m"
                       "%s"
                       "\x1b[0m"
                       "\n\t\tR/C:\t",
                       (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
                for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                    printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                           consumer->getOperator()->getNbRequiredData(inId));
                }
                printf("%ld/%ld", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
                       consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
                printf("\n\t\tP:\t");
                for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
                    printf("%ld\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
                }
                printf("%ld", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
                printf("\n");
            }
            bool isRunnable = true;

            IOIndex_t parentID = 0;  // FIXME: handle this correctly
            // Check every input has got enought data to run
            for (const auto& consumerParent : consumer->dataInputs()) {
                if (consumerParent.first &&
                    consumer->getOperator()->getNbRequiredData(parentID++) >
                            consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
                    // not enough data to run
                    isRunnable = false;
                    break;
                }
            }

            if (isRunnable) {
                runnableConsumers.insert(consumer);
            }
        }

        // Push consumers in the list of nodes to run and update the consumer producer system
        for (const auto& runnable : runnableConsumers) {
            runnable->getOperator()->updateConsummerProducer();
            mStaticSchedule.push_back(runnable);
        }

        // update producers and consumers list
        if (verbose) printf("Updating producer and consumer lists...\n");
        const auto oldConsumers = consumers;

        for (const auto& consumer : oldConsumers) {
            if (verbose) {
                printf("\t- consumer: %s\n\t\tR/C:\t",
                       (consumer->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(consumer.get()))).c_str());
                for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
                    printf("%ld/%ld\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
                           consumer->getOperator()->getNbRequiredData(inId));
                }
                printf("%ld/%ld", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
                       consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
                printf("\n\t\tP:\t");
                for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
                    printf("%ld\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
                }
                printf("%ld", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
                printf("\n");
            }
            bool isStillConsumer = false;

            IOIndex_t parentID = 0;  // FIXME: handle this correctly
            // should we check input or dataInput ?
            for (const auto& consumerParent : consumer->inputs()) {
                if (consumerParent.first &&
                    consumer->getOperator()->getNbConsumedData(parentID++) <
                            consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
                    // there is still data to consume
                    isStillConsumer = true;
                    break;
                }
            }

            for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
                if (consumer->getOperator()->getNbProducedData(outId) > 0) {
                    if (verbose) printf("  also producer\n");
                    // make sure consumer is also a producer
                    producers.insert(consumer);

                    const auto& childs = consumer->getChildren();
                    consumers.insert(childs.begin(), childs.end());
                    break;
                }
            }

            if (!isStillConsumer) {
                if (verbose) printf("  no more consumer\n");
                // consumer is no longer a consumer, only a producer
                consumers.erase(consumer);
            }
        }

        if (verbose) printf("*************\n");
    } while (!consumers.empty());

}

// TODO: handle multiple inputs/outputs
/**
 * @brief Regenerate the static schedule and run every scheduled node once, in order.
 *
 * Records a (node, start, end) timing entry per run in mScheduling, which is
 * cleared first.
 * @param forwardDims If true, calls mGraphView->forwardDims() before scheduling.
 * @param verbose If true, prints the name of each node before running it;
 *                otherwise draws a progress bar on stdout.
 */
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
    if (forwardDims) { mGraphView->forwardDims(); }

    // Scheduled nodes whose every input satisfies consumed >= required when
    // they are reached. Only used to compute the progress ratio.
    std::set<std::shared_ptr<Node>> computationOver;

    mScheduling.clear();
    // Make sure the schedule built below starts empty, so a second forward()
    // does not also replay the nodes scheduled by the previous call.
    mStaticSchedule.clear();

    this->generateScheduling();

    // run sequentially every scheduled node once
    // TODO: handle memory allocation in scheduler
    // TODO: optimize memory usage
    for (const auto& runnable : mStaticSchedule) {
        bool computationOverForConsumer = true;
        for (IOIndex_t parentIDi = 0; parentIDi < runnable->nbInputs(); ++parentIDi) {
            if (runnable->getOperator()->getNbConsumedData(parentIDi) <
                runnable->getOperator()->getNbRequiredData(parentIDi)) {
                computationOverForConsumer = false;
                break;
            }
        }
        if (computationOverForConsumer) {
            computationOver.insert(runnable);
        }

        const std::string nodeName =
                runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()));
        if (verbose) {
            printf("run: %s\n", nodeName.c_str());
        } else {
            // mComputationNumber is 0 for a graph made only of Producers.
            // Avoid 0/0 (NaN): its later conversion to int is undefined behavior.
            const float progress = (mComputationNumber > 0)
                    ? static_cast<float>(computationOver.size()) / static_cast<float>(mComputationNumber)
                    : 1.0f;
            drawProgressBar(progress, 50, std::string("running ") + nodeName);
        }
        const auto tStart = std::chrono::high_resolution_clock::now();
        runnable->forward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
    }
    if (!verbose) drawProgressBar(1.0, 50, "                                   ");
    printf("\n");

}

/**
 * @brief Write the last recorded scheduling (mScheduling) as a Mermaid Gantt chart.
 *
 * The output file is `fileName + ".mmd"`. Each entry is written as
 * "<type>_<address> :<start>, <end>", with times in microseconds relative to
 * the start of the first recorded element. If the file cannot be opened, an
 * error is printed to stderr and nothing is written.
 * @param fileName Output path without the ".mmd" extension.
 */
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
    // Closes the FILE on every exit path; unique_ptr never calls it for null.
    struct FileCloser {
        void operator()(std::FILE* f) const noexcept { std::fclose(f); }
    };

    const std::string path = fileName + ".mmd";
    const std::unique_ptr<std::FILE, FileCloser> fp(std::fopen(path.c_str(), "w"));
    if (!fp) {
        std::fprintf(stderr, "saveSchedulingDiagram: could not open '%s' for writing\n", path.c_str());
        return;
    }
    std::fprintf(fp.get(), "gantt\ndateFormat x\naxisFormat %%s ms\n\n");

    if (!mScheduling.empty()) {
        const auto globalStart = mScheduling[0].start;

        for (const auto& element : mScheduling) {
            const std::string name =
                    element.node->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(element.node.get()));
            // count() returns the duration's rep type, whose width is platform
            // dependent: cast so the argument always matches %lld.
            const long long startUs = static_cast<long long>(
                    std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count());
            const long long endUs = static_cast<long long>(
                    std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
            std::fprintf(fp.get(), "%s :%lld, %lld\n", name.c_str(), startUs, endUs);
        }
    }

    std::fprintf(fp.get(), "\n");
}

/**
 * @brief Collect every child of the given producer nodes.
 * @param producers Nodes whose children are gathered.
 * @return The union (without duplicates) of getChildren() over all producers.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
        const std::set<std::shared_ptr<Node>>& producers) const {
    std::set<std::shared_ptr<Node>> children;

    for (const auto& source : producers) {
        for (const auto& child : source->getChildren()) {
            children.insert(child);
        }
    }

    return children;
}