diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index c0cccd5afb4e0cbfe8104ebfe3240d6918a7ce77..2992ef652c0ee632ea4d41c2817d35f487ab44c2 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -33,8 +33,8 @@ private: struct StaticSchedulingElement { StaticSchedulingElement( std::shared_ptr<Node> node_, - size_t early_, - size_t late_) + size_t early_ = static_cast<size_t>(-1), + size_t late_ = static_cast<size_t>(-1)) : node(node_), early(early_), late(late_) {} std::shared_ptr<Node> node; @@ -69,8 +69,17 @@ public: }; ~SequentialScheduler() = default; - void generateScheduling(bool verbose = false); - std::vector<StaticSchedulingElement> generateEarlyLateScheduling() const; + /** + * Generate full static scheduling of the GraphView. + * For each node, an earliest and latest possible execution logical step + * is specified. Nodes that may be scheduled at the same logical step have + * no data dependency and can be run in parallel. + */ + void generateScheduling(); + + /** + * Reset all scheduling and associated nodes producer consumer. + */ void resetScheduling(); /** @@ -92,26 +101,42 @@ public: */ void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); + /** + * @brief Save in a Markdown file the static scheduling with early and late relative order for the nodes. + * @param fileName Name of the generated file. + */ + void saveStaticSchedulingDiagram(const std::string& fileName) const; + /** * @brief Save in a Markdown file the order of layers execution. * @param fileName Name of the generated file. */ void saveSchedulingDiagram(const std::string& fileName) const; - - void saveStaticSchedulingDiagram(const std::string& fileName, const std::vector<StaticSchedulingElement>& scheduling) const; /** * @brief Return a vector of Node ordered by the order they are called by the scheduler * @return std::vector<std::shared_ptr<Node>> */ - inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const noexcept { - return mStaticSchedule.at(step); - } + std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const; inline std::shared_ptr<GraphView> getGraphView() const noexcept { return mGraphView; } private: + /** + * Generate an initial base scheduling for the GraphView. + * The scheduling is entirely sequential and garanteed to be valid w.r.t. + * each node producer-consumer model. + */ + std::vector<StaticSchedulingElement> generateBaseScheduling() const; + + /** + * Fill-in early and late scheduling step from initial base scheduling. + * For each node, specifies the earliest and latest possible execution + * logical step. + */ + void generateEarlyLateScheduling(std::vector<StaticSchedulingElement>& schedule) const; + /** * @brief Set of layers receiving an input from currently processing layers * @@ -129,7 +154,7 @@ private: /** @brief List of SchedulingElement (i.e: Nodes with their computation time) */ std::vector<SchedulingElement> mScheduling; /** @brief List of nodes ordered by their */ - std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule; + std::vector<std::vector<StaticSchedulingElement>> mStaticSchedule; size_t mStaticScheduleStep = 0; }; } // namespace Aidge diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 170aa6c271a4f08ff5ad2801b754b647fee56df6..4f1b444f997831acefd4b1e24344f6395a46eadc 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -23,7 +23,7 @@ void init_Scheduler(py::module& m){ .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>()) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) .def("resetScheduling", &SequentialScheduler::resetScheduling) - .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) + .def("generate_scheduling", &SequentialScheduler::generateScheduling) .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling, py::arg("step") = 0) ; } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index bb256b2a431f2be7be1ddff87d98695243a426f8..9ad0ac4f8207fff76bfdf2a4c76399cd6dd10c43 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -40,7 +40,13 @@ void drawProgressBar(double progress, int barWidth, const std::string& additiona fflush(stdout); } -void Aidge::SequentialScheduler::generateScheduling(bool verbose) { +void Aidge::SequentialScheduler::generateScheduling() { + auto schedule = generateBaseScheduling(); + generateEarlyLateScheduling(schedule); + mStaticSchedule.push_back(schedule); +} + +std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> Aidge::SequentialScheduler::generateBaseScheduling() const { // TODO: For loop on the list of node to run // run sequencially every runnable consumers once // TODO: handle memory allocation in scheduler @@ -59,15 +65,15 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { const auto producersConsumers = getConsumers(producers); consumers.insert(producersConsumers.begin(), producersConsumers.end()); - std::map<std::shared_ptr<Node>, std::string> namePtrTable; - if (verbose) namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + const std::map<std::shared_ptr<Node>, std::string> namePtrTable + = mGraphView->getRankedNodesName("{0} ({1}#{3})"); // Still consumers are consumers that were run by can still consume data. // They must be run AFTER the remaining consumer to ensure a non-greedy // producers-consumers model! std::set<std::shared_ptr<Node>> stillConsumers; - mStaticSchedule.push_back(std::vector<std::shared_ptr<Node>>()); + std::vector<StaticSchedulingElement> schedule; do { // 2) From the current consumers list, check if any prior consumer node @@ -80,33 +86,27 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // If the prior node is of another type, it replaces the initial consumer // in the new priorConsumers list. The initial consumer will become // again a consumer later, by construction. - if (verbose) fmt::print("List of consumers with their priors:\n"); + Log::debug("List of consumers with their priors:"); std::set<std::shared_ptr<Node>> requiredProducers; std::set<std::shared_ptr<Node>> priorConsumers; for (const auto& consumer : consumers) { - if (verbose) { - fmt::print("\t- consumer: "); - fmt::print(fg(fmt::color::orange), namePtrTable[consumer]); - fmt::print("\n"); - } + Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); const auto& prior = getPriorProducersConsumers(consumer); if (prior.isPrior) { - if (verbose) { - std::vector<std::string> requiredProducersName; - std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(), - std::back_inserter(requiredProducersName), - [&namePtrTable](auto val){ return namePtrTable[val]; }); - fmt::print("\t\trequired producers: {}\n", requiredProducersName); - - std::vector<std::string> priorConsumersName; - std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(), - std::back_inserter(priorConsumersName), - [&namePtrTable](auto val){ return namePtrTable[val]; }); - fmt::print("\t\tprior consumers: {}\n", priorConsumersName); - } + std::vector<std::string> requiredProducersName; + std::transform(prior.requiredProducers.begin(), prior.requiredProducers.end(), + std::back_inserter(requiredProducersName), + [&namePtrTable](auto val){ return namePtrTable.at(val); }); + Log::debug("\t\trequired producers: {}", requiredProducersName); + + std::vector<std::string> priorConsumersName; + std::transform(prior.priorConsumers.begin(), prior.priorConsumers.end(), + std::back_inserter(priorConsumersName), + [&namePtrTable](auto val){ return namePtrTable.at(val); }); + Log::debug("\t\tprior consumers: {}", priorConsumersName); requiredProducers.insert(prior.requiredProducers.cbegin(), prior.requiredProducers.cend()); priorConsumers.insert(prior.priorConsumers.cbegin(), prior.priorConsumers.cend()); @@ -125,7 +125,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // Producers are special nodes that generate data on demand. for (const auto& requiredProducer : requiredProducers) { requiredProducer->getOperator()->updateConsummerProducer(); - mStaticSchedule.back().push_back(requiredProducer); + schedule.push_back(StaticSchedulingElement(requiredProducer)); } // 5) Find runnable consumers. @@ -134,32 +134,32 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // runnable because some may depend on the execution of others (when // there is multiple successive priors for example). std::set<std::shared_ptr<Node>> runnableConsumers; - if (verbose) fmt::print("Updated list of consumers:\n"); + Log::debug("Updated list of consumers:"); for (const auto& consumer : consumers) { - if (verbose) { - fmt::print("\t- consumer: "); - fmt::print(fg(fmt::color::orange), namePtrTable[consumer]); - fmt::print("\n\t\tC/R:\t"); - for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), - consumer->getOperator()->getNbRequiredData(inId)); - } - fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), - consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); - fmt::print("\n\t\tP:\t"); - for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); - } - fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); - fmt::print("\n"); + Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); + + std::string crLog = "\t\tC/R:\t"; + for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { + crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + consumer->getOperator()->getNbRequiredData(inId)); } - + crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); + Log::debug("{}", crLog); + + std::string pLog = "\t\tP:\t"; + for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { + pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); + } + pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + Log::debug("{}", pLog); + bool isRunnable = true; for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0 && */(consumer->getOperator()->getNbConsumedData(inputIdx) + consumer->getOperator()->getNbRequiredData(inputIdx)) > getNbAvailableData(consumer, inputIdx)) { - if (verbose) fmt::print(" not runnable: C{} + R{} > P{} for input #{}\n", + Log::debug(" not runnable: C{} + R{} > P{} for input #{}", consumer->getOperator()->getNbConsumedData(inputIdx), consumer->getOperator()->getNbRequiredData(inputIdx), getNbAvailableData(consumer, inputIdx), inputIdx); @@ -177,7 +177,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // 5) If not consumer is runnable, it is a stop condition! if (runnableConsumers.empty()) { - if (verbose) fmt::print("********************\n"); + Log::debug("********************"); // No consumer is runnable: some required data is missing for all of // them. There is two possibilities: // - At least one required data source is exhausted, which may be @@ -192,30 +192,31 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // At this point, simultaneously runnable consumers have no data // dependency and could be run in parallel! for (const auto& runnable : runnableConsumers) { - if (verbose) fmt::print("Runnable: {}\n", namePtrTable[runnable]); + Log::debug("Runnable: {}", namePtrTable.at(runnable)); runnable->getOperator()->updateConsummerProducer(); - mStaticSchedule.back().push_back(runnable); + schedule.push_back(StaticSchedulingElement(runnable)); } // 7) Update consumers list - if (verbose) fmt::print("Updating producer and consumer lists...\n"); + Log::debug("Updating producer and consumer lists..."); for (const auto& consumer : runnableConsumers) { - if (verbose) { - fmt::print("\t- consumer: {}\n\t\tC/R:\t", - namePtrTable[consumer]); - for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { - fmt::print("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), - consumer->getOperator()->getNbRequiredData(inId)); - } - fmt::print("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), - consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); - fmt::print("\n\t\tP:\t"); - for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { - fmt::print("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); - } - fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); - fmt::print("\n"); + Log::debug("\t- consumer: {}", fmt::styled(namePtrTable.at(consumer), fg(fmt::color::orange))); + + std::string crLog = "\t\tC/R:\t"; + for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) { + crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId), + consumer->getOperator()->getNbRequiredData(inId)); + } + crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1), + consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1)); + Log::debug("{}", crLog); + + std::string pLog = "\t\tP:\t"; + for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) { + pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId)); } + pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1)); + Log::debug("{}", pLog); // 7.1) If the current consumer has still data to consume, it will // be put back in the consumers list once the remaining consumers @@ -224,7 +225,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { if (consumer->getOperator()->getNbConsumedData(inputIdx) < getNbAvailableData(consumer, inputIdx)) { - if (verbose) fmt::print(" still consumer: C{} < P{} for input #{}\n", + Log::debug(" still consumer: C{} < P{} for input #{}", consumer->getOperator()->getNbConsumedData(inputIdx), getNbAvailableData(consumer, inputIdx), inputIdx); @@ -253,7 +254,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } /* if (consumer->getOperator()->getNbProducedData(outId) > 0) { - if (verbose) fmt::print(" also producer\n"); + Log::debug(" also producer"); // make sure consumer is also a producer producers.insert(consumer); @@ -267,7 +268,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { consumers.erase(consumer); if (isProducer) { - if (verbose) fmt::print(" also producer\n"); + Log::debug(" also producer"); // make sure consumer is also a producer producers.insert(consumer); @@ -290,79 +291,70 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { stillConsumers.clear(); } - if (verbose) fmt::print("********************\n"); + Log::debug("********************"); } while (!consumers.empty()); - if (verbose) { - if (!consumers.empty()) { - fmt::print("/!\\ Remaining consumers: possible dead-lock\n"); - fmt::print("********************\n"); - } + if (!consumers.empty()) { + Log::warn("Remaining consumers: possible dead-lock"); } -} -std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> -Aidge::SequentialScheduler::generateEarlyLateScheduling() const { - std::vector<StaticSchedulingElement> scheduling; + return schedule; +} +void Aidge::SequentialScheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingElement>& schedule) const { size_t latest = 0; - for (size_t elt = 0; elt < mStaticSchedule.at(0).size(); ++elt) { - const auto node = mStaticSchedule.at(0)[elt]; - const auto itNode = std::find_if(scheduling.rbegin(), scheduling.rend(), [node](const auto& v) { return (v.node == node); }); + // Calculate early (logical) start + for (size_t elt = 0; elt < schedule.size(); ++elt) { + const auto node = schedule[elt].node; + const auto itNode = std::find_if(schedule.rend() - elt, schedule.rend(), + [node](const auto& v) { return (v.node == node); }); - // Find early step: node can be run after the latest parent was run - // also, node must be run after latest node run! + // Node can be run the earliest just after it was run the last time! size_t early = 0; - if (itNode != scheduling.rend()) { + if (itNode != schedule.rend()) { early = (*itNode).early + 1; } + // Node can be run the earliest just after its latest parent was run for (const auto parent : node->getParents()) { // Find parent node latest scheduled position - const auto it = std::find(mStaticSchedule.at(0).rend() - elt, mStaticSchedule.at(0).rend(), parent); - if (it != mStaticSchedule.at(0).rend()) { - const size_t step = std::distance(mStaticSchedule.at(0).begin(), it.base()) - 1; - early = std::max(early, scheduling[step].early + 1); - } - } -/* - // Update late step for parents - for (const auto parent : node->getParents()) { - const auto it = std::find(mStaticSchedule.at(0).rend() - elt, mStaticSchedule.at(0).rend(), parent); - if (it != mStaticSchedule.at(0).rend()) { - const size_t step = std::distance(mStaticSchedule.at(0).begin(), it.base()) - 1; - scheduling[step].late = std::min(scheduling[step].late, early - 1); - latest = std::max(latest, scheduling[step].late); + const auto it = std::find_if(schedule.rend() - elt, schedule.rend(), + [parent](const auto& v) { return (v.node == parent); }); + if (it != schedule.rend()) { + const size_t step = std::distance(schedule.begin(), it.base()) - 1; + early = std::max(early, schedule[step].early + 1); } } -*/ + latest = std::max(latest, early); - size_t late = static_cast<size_t>(-1); - scheduling.push_back(StaticSchedulingElement(node, early, late)); + schedule[elt].early = early; } - for (size_t elt = mStaticSchedule.at(0).size(); elt-- != 0; ) { - const auto node = mStaticSchedule.at(0)[elt]; - const auto itNode = std::find_if(scheduling.begin() + elt + 1, scheduling.end(), [node](const auto& v) { return (v.node == node); }); + // Calculate late (logical) start + for (size_t elt = schedule.size(); elt-- != 0; ) { + const auto node = schedule[elt].node; + const auto itNode = std::find_if(schedule.begin() + elt + 1, schedule.end(), + [node](const auto& v) { return (v.node == node); }); + // Node can be run the latest just before it is run the next time! size_t late = latest; - if (itNode != scheduling.end()) { + if (itNode != schedule.end()) { late = (*itNode).late - 1; } + // Node can be run the latest just before its earliest child is run for (const auto child : node->getChildren()) { // Find child node earliest scheduled position - const auto it = std::find(mStaticSchedule.at(0).begin() + elt + 1, mStaticSchedule.at(0).end(), child); - if (it != mStaticSchedule.at(0).end()) { - const size_t step = std::distance(mStaticSchedule.at(0).begin(), it); - late = std::min(late, scheduling[step].late - 1); + const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), + [child](const auto& v) { return (v.node == child); }); + if (it != schedule.end()) { + const size_t step = std::distance(schedule.begin(), it); + late = std::min(late, schedule[step].late - 1); } } - scheduling[elt].late = late; + schedule[elt].late = late; } - - return scheduling; } void Aidge::SequentialScheduler::resetScheduling() { @@ -381,8 +373,8 @@ void Aidge::SequentialScheduler::resetScheduling() { Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducers, bool wrapAroundBuffer) const { MemoryManager memManager; - for (const auto& shedule : mStaticSchedule) { - for (const auto& node : shedule) { + for (size_t step = 0; step < mStaticSchedule.size(); ++step) { + for (const auto& node : getStaticScheduling(step)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; @@ -512,13 +504,13 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve // If scheduling was already generated (in one or several steps, i.e. one or // several successive call to generateScheduling()), do not generate it twice if (mStaticSchedule.empty()) { - this->generateScheduling(verbose); + this->generateScheduling(); } const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); size_t cpt = 0; - for (const auto& runnable : mStaticSchedule.at(mStaticScheduleStep)) { + for (const auto& runnable : getStaticScheduling(mStaticScheduleStep)) { if (verbose) fmt::print("run: {}\n", namePtrTable.at(runnable)); else @@ -569,7 +561,7 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa fmt::print(fp.get(), "\n"); } -void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName, const std::vector<StaticSchedulingElement>& scheduling) const { +void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -579,23 +571,32 @@ void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n"); - if (!scheduling.empty()) { + if (!mStaticSchedule.empty()) { const std::map<std::shared_ptr<Node>, std::string> namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - for (const auto& element : scheduling) { - auto name = namePtrTable.at(element.node); - // Mermaid does not allow : character in task title - std::replace(name.begin(), name.end(), ':', '_'); + for (const auto& schedule : mStaticSchedule) { + for (const auto& element : schedule) { + auto name = namePtrTable.at(element.node); + // Mermaid does not allow : character in task title + std::replace(name.begin(), name.end(), ':', '_'); - fmt::print(fp.get(), "{} :{}, {}\n", - name, element.early, element.late); + fmt::print(fp.get(), "{} :{}, {}\n", + name, element.early, element.late); + } } } fmt::print(fp.get(), "\n"); } +std::vector<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getStaticScheduling(size_t step) const { + const auto& staticSchedule = mStaticSchedule.at(step); + std::vector<std::shared_ptr<Node>> schedule; + std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v.node; }); + return schedule; +} + std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( const std::set<std::shared_ptr<Node>>& producers) const { std::set<std::shared_ptr<Node>> consumers; diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 7e28f1fadc56855d266c1e8547261f5903f8c724..563fc19907e8800bbe6dec9e23242c6bb2c90c43 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -58,7 +58,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { g1->forwardDims(); auto scheduler = SequentialScheduler(g1); - scheduler.generateScheduling(true); + scheduler.generateScheduling(); const auto sch = scheduler.getStaticScheduling(); const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})");