diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 747785bf886889aed273c944904ddbb6198c4968..c0cccd5afb4e0cbfe8104ebfe3240d6918a7ce77 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -30,6 +30,18 @@ class GraphView; class SequentialScheduler { private: + struct StaticSchedulingElement { + StaticSchedulingElement( + std::shared_ptr<Node> node_, + size_t early_, + size_t late_) + : node(node_), early(early_), late(late_) {} + + std::shared_ptr<Node> node; + size_t early; + size_t late; + }; + struct SchedulingElement { SchedulingElement( std::shared_ptr<Node> node_, @@ -58,6 +70,7 @@ public: ~SequentialScheduler() = default; void generateScheduling(bool verbose = false); + std::vector<StaticSchedulingElement> generateEarlyLateScheduling() const; void resetScheduling(); /** @@ -84,6 +97,8 @@ public: * @param fileName Name of the generated file. */ void saveSchedulingDiagram(const std::string& fileName) const; + + void saveStaticSchedulingDiagram(const std::string& fileName, const std::vector<StaticSchedulingElement>& scheduling) const; /** * @brief Return a vector of Node ordered by the order they are called by the scheduler diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 2975538bc3271f4dbf6faea920be3a05452a0859..bb256b2a431f2be7be1ddff87d98695243a426f8 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -301,6 +301,70 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } } +std::vector<Aidge::SequentialScheduler::StaticSchedulingElement> +Aidge::SequentialScheduler::generateEarlyLateScheduling() const { + std::vector<StaticSchedulingElement> scheduling; + + size_t latest = 0; + for (size_t elt = 0; elt < mStaticSchedule.at(0).size(); ++elt) { + const auto node = mStaticSchedule.at(0)[elt]; + const auto itNode = std::find_if(scheduling.rbegin(), scheduling.rend(), [node](const auto& v) { return (v.node == node); }); + + // Find early step: node can be run after the latest parent was run + // also, node must be run after latest node run! + size_t early = 0; + if (itNode != scheduling.rend()) { + early = (*itNode).early + 1; + } + + for (const auto parent : node->getParents()) { + // Find parent node latest scheduled position + const auto it = std::find(mStaticSchedule.at(0).rend() - elt, mStaticSchedule.at(0).rend(), parent); + if (it != mStaticSchedule.at(0).rend()) { + const size_t step = std::distance(mStaticSchedule.at(0).begin(), it.base()) - 1; + early = std::max(early, scheduling[step].early + 1); + } + } +/* + // Update late step for parents + for (const auto parent : node->getParents()) { + const auto it = std::find(mStaticSchedule.at(0).rend() - elt, mStaticSchedule.at(0).rend(), parent); + if (it != mStaticSchedule.at(0).rend()) { + const size_t step = std::distance(mStaticSchedule.at(0).begin(), it.base()) - 1; + scheduling[step].late = std::min(scheduling[step].late, early - 1); + latest = std::max(latest, scheduling[step].late); + } + } +*/ + latest = std::max(latest, early); + size_t late = static_cast<size_t>(-1); + scheduling.push_back(StaticSchedulingElement(node, early, late)); + } + + for (size_t elt = mStaticSchedule.at(0).size(); elt-- != 0; ) { + const auto node = mStaticSchedule.at(0)[elt]; + const auto itNode = std::find_if(scheduling.begin() + elt + 1, scheduling.end(), [node](const auto& v) { return (v.node == node); }); + + size_t late = latest; + if (itNode != scheduling.end()) { + late = (*itNode).late - 1; + } + + for (const auto child : node->getChildren()) { + // Find child node earliest scheduled position + const auto it = std::find(mStaticSchedule.at(0).begin() + elt + 1, mStaticSchedule.at(0).end(), child); + if (it != mStaticSchedule.at(0).end()) { + const size_t step = std::distance(mStaticSchedule.at(0).begin(), it); + late = std::min(late, scheduling[step].late - 1); + } + } + + scheduling[elt].late = late; + } + + return scheduling; +} + void Aidge::SequentialScheduler::resetScheduling() { for (auto node : mGraphView->getNodes()) { node->getOperator()->resetConsummerProducer(); @@ -505,6 +569,33 @@ void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileNa fmt::print(fp.get(), "\n"); } +void Aidge::SequentialScheduler::saveStaticSchedulingDiagram(const std::string& fileName, const std::vector<StaticSchedulingElement>& scheduling) const { + auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); + + if (!fp) { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Could not create scheduling diagram log file: {}", fileName + ".mmd"); + } + + fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n"); + + if (!scheduling.empty()) { + const std::map<std::shared_ptr<Node>, std::string> namePtrTable + = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + + for (const auto& element : scheduling) { + auto name = namePtrTable.at(element.node); + // Mermaid does not allow : character in task title + std::replace(name.begin(), name.end(), ':', '_'); + + fmt::print(fp.get(), "{} :{}, {}\n", + name, element.early, element.late); + } + } + + fmt::print(fp.get(), "\n"); +} + std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers( const std::set<std::shared_ptr<Node>>& producers) const { std::set<std::shared_ptr<Node>> consumers;