diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 75fc1577831563a485202947da9e7f1a030f5deb..db9b903cc9f90ef3d84252f34867b23cfe13d237 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -187,7 +187,7 @@ public: * @param fileName Name of the file to save the diagram (without extension). */ void saveStaticSchedulingDiagram(const std::string& fileName) const; - void saveFactorizedStaticSchedulingDiagram(const std::string& fileName) const; + void saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat = 2) const; /** * @brief Save in a Mermaid file the order of layers execution. @@ -239,11 +239,12 @@ protected: * in the scheduling. * * @param schedule Vector of shared pointers to StaticSchedulingElements to be processed + * @param size_t Minimum number repetitions to factorize the sequence * @return Vector containing the repetitive sequences, in order. The second * element of the pair is the number of repetitions. */ std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> - getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule) const; + getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule, size_t minRepeat = 2) const; private: /** diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 10dc66ae86264a539266a807449d190fa0b59519..1a3a4b6b2f552b05ae33fea0ecff43c4f1ec689f 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -32,7 +32,7 @@ void init_Scheduler(py::module& m){ .def("graph_view", &Scheduler::graphView) .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name")) .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name")) - .def("save_factorized_static_scheduling_diagram", &Scheduler::saveFactorizedStaticSchedulingDiagram, py::arg("file_name")) + .def("save_factorized_static_scheduling_diagram", &Scheduler::saveFactorizedStaticSchedulingDiagram, py::arg("file_name"), py::arg("min_repeat") = 2) .def("resetScheduling", &Scheduler::resetScheduling) .def("generate_scheduling", &Scheduler::generateScheduling) .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::EarlyLateSort::Default) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 26bb357e78c99a2119c4a35d28d734f5aa8e07d2..bbbc3d8076cc1ff5fe25e9f60e4827b94ef0de4f 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -428,24 +428,19 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } std::vector<std::pair<std::vector<Aidge::Scheduler::StaticSchedulingElement*>, size_t>> -Aidge::Scheduler::getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule) const +Aidge::Scheduler::getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule, size_t minRepeat) const { std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> sequences; size_t offset = 0; for (size_t i = 0; i < schedule.size(); ) { - std::vector<StaticSchedulingElement*> seq; - seq.push_back(new StaticSchedulingElement( - schedule[i]->node, - schedule[i]->early - offset, - schedule[i]->late - offset)); - // Find all the possible repetitive sequences starting from this element - std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> longuestSeq = {std::make_pair(seq, 1)}; - std::vector<size_t> longuestSeqOffset = {0}; + std::vector<StaticSchedulingElement*> seq; + std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> longuestSeq; + std::vector<size_t> longuestSeqOffset; - for (size_t k = i + 1; k < schedule.size() - 1; ++k) { - // For each sequence length, starting from 2... + for (size_t k = i; k < schedule.size(); ++k) { + // For each sequence length, starting from 1... seq.push_back(new StaticSchedulingElement( schedule[k]->node, schedule[k]->early - offset, @@ -454,7 +449,7 @@ Aidge::Scheduler::getFactorizedScheduling(const std::vector<StaticSchedulingElem size_t start = k + 1; size_t nbRepeats = 1; bool repeat = true; - const auto seqOffset = schedule[start]->early - offset - seq[0]->early; + const auto seqOffset = (start < schedule.size()) ? schedule[start]->early - offset - seq[0]->early : 0; do { // Count the number of consecutive sequences (repetitions) @@ -476,11 +471,17 @@ Aidge::Scheduler::getFactorizedScheduling(const std::vector<StaticSchedulingElem } while (repeat); - if (nbRepeats > 1) { + if (nbRepeats >= minRepeat) { // If repetitions exist for this sequence length, add it to the list longuestSeq.push_back(std::make_pair(seq, nbRepeats)); longuestSeqOffset.push_back(seqOffset); } + else if (k == i) { + // Ensure that at least the current element is in the list if no + // repetition is found + longuestSeq.push_back(std::make_pair(seq, 1)); + longuestSeqOffset.push_back(0); + } } // Select the one with the best factorization @@ -960,7 +961,7 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) fmt::print(fp.get(), "\n"); } -void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -975,7 +976,7 @@ void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& = mGraphView->getRankedNodesName("{0} ({1}#{3})"); for (const auto& schedule : mStaticSchedule) { - const auto factorizedSchedule = getFactorizedScheduling(schedule); + const auto factorizedSchedule = getFactorizedScheduling(schedule, minRepeat); size_t seq = 0; for (const auto& sequence : factorizedSchedule) { @@ -990,15 +991,18 @@ void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& auto name = namePtrTable.at(element->node); // Mermaid does not allow : character in task title std::replace(name.begin(), name.end(), ':', '_'); + std::string tag = ":"; if (element->early == element->late) { - fmt::print(fp.get(), "{} :milestone, {}, {}\n", - name, element->early, element->late); + tag += "milestone, "; } - else { - fmt::print(fp.get(), "{} :{}, {}\n", - name, element->early, element->late); + + if (sequence.second > 1) { + tag += "active, "; } + + fmt::print(fp.get(), "{} {}{}, {}\n", + name, tag, element->early, element->late); } ++seq; }