/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/scheduler/Scheduler.hpp"

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>

#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"

/**
 * @brief Draw a single-line text progress bar on stdout, overwriting the
 * current line (the line ends with '\r').
 *
 * @param progress Completion ratio. Values outside [0, 1], and NaN, are clamped.
 * @param barWidth Number of character cells in the bar.
 * @param additionalInfo Text printed after the percentage.
 */
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
    // Clamp so an out-of-range ratio can neither over-fill the bar nor print a
    // bogus percentage. Written as !(x >= 0) so that NaN is also mapped to 0.
    if (!(progress >= 0.0)) {
        progress = 0.0;
    } else if (progress > 1.0) {
        progress = 1.0;
    }

    const int filled = static_cast<int>(barWidth * progress);
    putchar('[');
    for (int i = 0; i < barWidth; ++i) {
        // Strict '<': 0% draws an empty bar and k% fills exactly k% of the cells.
        putchar(i < filled ? '#' : ' ');
    }
    printf("] %d%% | %s\r", static_cast<int>(progress * 100), additionalInfo.c_str());
    fflush(stdout);
}

namespace {

/// Display name of a node used in logs and diagrams: "<type>_<address>".
std::string schedulerNodeName(const std::shared_ptr<Aidge::Node>& node) {
    return node->type() + "_" + std::to_string(reinterpret_cast<std::uintptr_t>(node.get()));
}

/// Verbose-mode dump of a node's data counters.
/// Prints one "consumed/required" pair per input, then "\n\t\tP:\t" followed
/// by the produced count of each output, then a newline. Successive values
/// are separated by "\n\t\t\t". Iterating with a leading separator, instead
/// of special-casing index count-1, means a node with no inputs or outputs
/// never queries an out-of-range index.
void printSchedulerNodeState(const std::shared_ptr<Aidge::Node>& node) {
    const auto& op = node->getOperator();
    for (Aidge::IOIndex_t inId = 0; inId < node->nbInputs(); ++inId) {
        if (inId > 0) {
            printf("\n\t\t\t");
        }
        // Cast to a fixed printf type: the counters' exact integer type is not
        // guaranteed to match "%ld" on every platform.
        printf("%llu/%llu",
               static_cast<unsigned long long>(op->getNbConsumedData(inId)),
               static_cast<unsigned long long>(op->getNbRequiredData(inId)));
    }
    printf("\n\t\tP:\t");
    for (Aidge::IOIndex_t outId = 0; outId < node->nbOutputs(); ++outId) {
        if (outId > 0) {
            printf("\n\t\t\t");
        }
        printf("%llu", static_cast<unsigned long long>(op->getNbProducedData(outId)));
    }
    printf("\n");
}

/// Deleter that closes a C stream owned by a std::unique_ptr.
struct FileCloser {
    void operator()(std::FILE* fp) const noexcept { std::fclose(fp); }
};

}  // namespace

/**
 * @brief Build the static execution order of the graph into mStaticSchedule.
 *
 * Starting from the graph input nodes, this function repeats the following
 * until no consumer is left:
 *  - it collects every consumer whose data inputs expose enough produced data;
 *  - for each such runnable consumer, it updates that operator's
 *    consumer/producer counters and appends the node to mStaticSchedule;
 *  - it refreshes the consumer set: children of nodes that produced data are
 *    added, and nodes with nothing left to consume are removed.
 *
 * mComputationNumber is set to the number of non-"Producer" nodes.
 * Nodes are appended to mStaticSchedule; the existing content is not cleared.
 *
 * @param verbose If true, print the counters of every examined node on stdout.
 *
 * NOTE(review): the loop exits only when the consumer set becomes empty. A
 * graph where some consumer can never become runnable is not detected here.
 */
void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
    // Bind once: getNodes() may return by value. Comparing find() on one
    // temporary against end() of another temporary would be undefined behavior.
    const auto& graphNodes = mGraphView->getNodes();

    // setup initial producers list
    mComputationNumber = 0;
    std::set<std::shared_ptr<Node>> producers;
    for (const std::shared_ptr<Node>& nodePtr : graphNodes) {
        if (nodePtr->type() == "Producer") {
            producers.insert(nodePtr);
        } else {
            ++mComputationNumber;
        }
    }

    // add Data Input
    // FIXME : should be changed when the real system for providing
    // data is implemented
    for (const std::shared_ptr<Node>& nodePtr : mGraphView->inputNodes()) {
        for (const auto& parentPtr : nodePtr->getParents()) {
            if (parentPtr && graphNodes.find(parentPtr) == graphNodes.end()) {
                // Node not found in the graph, it's an outside producer
                producers.insert(parentPtr);
            }
        }
    }
    // NOTE(review): `producers` is currently only filled, never read, by the
    // scheduling loop below.

    // setup consumer list
    // std::set<std::shared_ptr<Node>> consumers = getConsumers(producers);

    /* It may not be necessary to initialize producer */
    std::set<std::shared_ptr<Node>> consumers = mGraphView->inputNodes();
    do {
        // find runnable consumers
        std::set<std::shared_ptr<Node>> runnableConsumers;
        if (verbose) printf("List of layers receiving data:\n");
        for (const auto& consumer : consumers) {
            if (verbose) {
                printf("\t- consumer: "
                       "\x1b[1;37m"
                       "%s"
                       "\x1b[0m"
                       "\n\t\tR/C:\t",
                       schedulerNodeName(consumer).c_str());
                printSchedulerNodeState(consumer);
            }

            bool isRunnable = true;
            IOIndex_t parentID = 0;
            // Check every data input has got enough data to run.
            // The index advances on every iteration, including unconnected
            // inputs, so it always matches the position of consumerParent.
            for (const auto& consumerParent : consumer->dataInputs()) {
                const IOIndex_t inputIdx = parentID++;
                if (consumerParent.first &&
                    consumer->getOperator()->getNbRequiredData(inputIdx) >
                        consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
                    // not enough data to run
                    isRunnable = false;
                    break;
                }
            }

            if (isRunnable) {
                runnableConsumers.insert(consumer);
            }
        }

        // Push consumers in the list of nodes to run and update the consumer producer system
        for (const auto& runnable : runnableConsumers) {
            runnable->getOperator()->updateConsummerProducer();
            mStaticSchedule.push_back(runnable);
        }

        // update producers and consumers list
        if (verbose) printf("Updating producer and consumer lists...\n");
        // Iterate over a snapshot because `consumers` is modified in the loop.
        const auto oldConsumers = consumers;

        for (const auto& consumer : oldConsumers) {
            if (verbose) {
                printf("\t- consumer: %s\n\t\tR/C:\t", schedulerNodeName(consumer).c_str());
                printSchedulerNodeState(consumer);
            }

            bool isStillConsumer = false;
            IOIndex_t parentID = 0;
            // should we check input or dataInput ?
            for (const auto& consumerParent : consumer->inputs()) {
                const IOIndex_t inputIdx = parentID++;
                if (consumerParent.first &&
                    consumer->getOperator()->getNbConsumedData(inputIdx) <
                        consumerParent.first->getOperator()->getNbProducedData(consumerParent.second)) {
                    // there is still data to consume
                    isStillConsumer = true;
                    break;
                }
            }

            for (IOIndex_t outId = 0; outId < consumer->nbOutputs(); ++outId) {
                if (consumer->getOperator()->getNbProducedData(outId) > 0) {
                    if (verbose) printf(" also producer\n");
                    // make sure consumer is also a producer
                    producers.insert(consumer);

                    const auto& childs = consumer->getChildren();
                    consumers.insert(childs.begin(), childs.end());
                    break;
                }
            }

            if (!isStillConsumer) {
                if (verbose) printf(" no more consumer\n");
                // consumer is no longer a consumer, only a producer
                consumers.erase(consumer);
            }
        }

        if (verbose) printf("*************\n");
    } while (!consumers.empty());
}

/**
 * @brief Run every node of the static schedule once, in order, and record
 * the start and end time of each run in mScheduling.
 *
 * @param forwardDims If true, call mGraphView->forwardDims() first.
 * @param verbose If true, print each node as it runs; otherwise draw a progress bar.
 */
// TODO: handle multiple inputs/outputs
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
    if (forwardDims) { mGraphView->forwardDims(); }

    // add each Producer Node.
    std::set<std::shared_ptr<Node>> computationOver;

    mScheduling.clear();

    // generateScheduling() appends to mStaticSchedule. Regenerating on every
    // call would duplicate the schedule and run each node several times, so
    // the schedule is generated only when none exists yet.
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    // TODO: For loop on the list of node to run
    // run sequencially every runnable consumers once
    // TODO: handle memory allocation in scheduler
    // TODO: optimize memory usage
    for (const auto& runnable : mStaticSchedule) {
        bool computationOverForConsumer = true;
        for (IOIndex_t parentIDi = 0; parentIDi < runnable->nbInputs(); ++parentIDi) {
            if (runnable->getOperator()->getNbConsumedData(parentIDi) <
                runnable->getOperator()->getNbRequiredData(parentIDi)) {
                computationOverForConsumer = false;
                break;
            }
        }
        if (computationOverForConsumer) {
            computationOver.insert(runnable);
        }

        if (verbose) {
            printf("run: %s\n", schedulerNodeName(runnable).c_str());
        } else {
            // Guard the ratio: a graph without non-Producer nodes gives 0 here.
            const float progress =
                (mComputationNumber > 0)
                    ? static_cast<float>(computationOver.size()) / static_cast<float>(mComputationNumber)
                    : 1.0f;
            drawProgressBar(progress, 50, std::string("running ") + schedulerNodeName(runnable));
        }

        const auto tStart = std::chrono::high_resolution_clock::now();
        runnable->forward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
    }

    if (!verbose) drawProgressBar(1.0, 50, "                                   ");
    printf("\n");
}

/**
 * @brief Write the last recorded execution (mScheduling) as a Mermaid Gantt
 * diagram to "<fileName>.mmd".
 *
 * Each line holds a node's name and its start and end times, in microseconds
 * relative to the first recorded start.
 *
 * @param fileName Output path without the ".mmd" extension.
 * @throws std::runtime_error if the file cannot be opened for writing.
 */
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
    const std::string path = fileName + ".mmd";
    // RAII: the stream is closed even if formatting below throws.
    const std::unique_ptr<std::FILE, FileCloser> file(std::fopen(path.c_str(), "w"));
    if (!file) {
        throw std::runtime_error("Could not create scheduling diagram log file: " + path);
    }
    std::FILE* fp = file.get();

    std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%s ms\n\n");

    if (!mScheduling.empty()) {
        const auto globalStart = mScheduling[0].start;

        for (const auto& element : mScheduling) {
            // duration::count() is `long` on some platforms and `long long` on
            // others; cast so it always matches "%lld".
            const long long startUs = static_cast<long long>(
                std::chrono::duration_cast<std::chrono::microseconds>(element.start - globalStart).count());
            const long long endUs = static_cast<long long>(
                std::chrono::duration_cast<std::chrono::microseconds>(element.end - globalStart).count());
            std::fprintf(fp, "%s :%lld, %lld\n", schedulerNodeName(element.node).c_str(), startUs, endUs);
        }
    }

    std::fprintf(fp, "\n");
}
/**
 * @brief Collect the union of the children of every given producer.
 *
 * @param producers Nodes whose children are gathered.
 * @return Every child of at least one producer, with duplicates merged by the set.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
        const std::set<std::shared_ptr<Node>>& producers) const {
    std::set<std::shared_ptr<Node>> result;
    for (const auto& producer : producers) {
        for (const auto& child : producer->getChildren()) {
            result.insert(child);
        }
    }
    return result;
}