diff --git a/include/aidge/scheduler/ProdConso.hpp b/include/aidge/scheduler/ProdConso.hpp index f30e00afa082658fce1eca8b4cb885e1b23fb7c7..bc42cb36c309952dbfc872d0722e1111aae6d3b7 100644 --- a/include/aidge/scheduler/ProdConso.hpp +++ b/include/aidge/scheduler/ProdConso.hpp @@ -34,6 +34,10 @@ public: return std::make_unique<ProdConso>(op, true); } + const Operator& getOperator() const noexcept { + return mOp; + } + /** * @brief Minimum amount of data from a specific input required by the * implementation to be run. diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 51f62ed1b7412abed4e0b850183fb101f208a69e..db9b903cc9f90ef3d84252f34867b23cfe13d237 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -187,6 +187,7 @@ public: * @param fileName Name of the file to save the diagram (without extension). */ void saveStaticSchedulingDiagram(const std::string& fileName) const; + void saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat = 2) const; /** * @brief Save in a Mermaid file the order of layers execution. @@ -233,6 +234,18 @@ protected: */ void generateEarlyLateScheduling(std::vector<StaticSchedulingElement*>& schedule) const; + /** + * @brief Get the factorized scheduling, by identifying repetitive sequences + * in the scheduling. + * + * @param schedule Vector of shared pointers to StaticSchedulingElements to be processed + * @param size_t Minimum number repetitions to factorize the sequence + * @return Vector containing the repetitive sequences, in order. The second + * element of the pair is the number of repetitions. + */ + std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> + getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule, size_t minRepeat = 2) const; + private: /** * @brief Summarize the consumer state of a node for debugging purposes. diff --git a/python_binding/data/pybind_Elts.cpp b/python_binding/data/pybind_Elts.cpp new file mode 100644 index 0000000000000000000000000000000000000000..59a8211e26c59cbcc4ec3e8456933403ba6c25e7 --- /dev/null +++ b/python_binding/data/pybind_Elts.cpp @@ -0,0 +1,85 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <algorithm> // std::transform +#include <cctype> // std::tolower +#include <string> // std::string +#include <vector> + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <pybind11/operators.h> + +#include "aidge/data/Elts.hpp" + +namespace py = pybind11; +namespace Aidge { + +template <class T> +void bindEnum(py::module& m, const std::string& name) { + // Define enumeration names for python as lowercase type name + // This defined enum names compatible with basic numpy type + // name such as: float32, flot64, [u]int32, [u]int64, ... + auto python_enum_name = [](const T& type) { + auto str_lower = [](std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c){ + return std::tolower(c); + }); + }; + auto type_name = std::string(Aidge::format_as(type)); + str_lower(type_name); + return type_name; + }; + + // Auto generate enumeration names from lowercase type strings + std::vector<std::string> enum_names; + for (auto type_str : EnumStrings<T>::data) { + auto type = static_cast<T>(enum_names.size()); + auto enum_name = python_enum_name(type); + enum_names.push_back(enum_name); + } + + // Define python side enumeration aidge_core.type + auto e_type = py::enum_<T>(m, name.c_str()); + + // Add enum value for each enum name + for (std::size_t idx = 0; idx < enum_names.size(); idx++) { + e_type.value(enum_names[idx].c_str(), static_cast<T>(idx)); + } + + // Define str() to return the bare enum name value, it allows + // to compare directly for instance str(tensor.type()) + // with str(nparray.type) + e_type.def("__str__", [enum_names](const T& type) { + return enum_names[static_cast<int>(type)]; + }, py::prepend()); +} + +void init_Elts(py::module& m) { + bindEnum<Elts_t::EltType>(m, "EltType"); + m.def("format_as", (const char* (*)(Elts_t::EltType)) &format_as, py::arg("elt")); + + py::class_<Elts_t, std::shared_ptr<Elts_t>>( + m, "Elts_t", py::dynamic_attr()) + .def_static("none_elts", &Elts_t::NoneElts) + .def_static("data_elts", &Elts_t::DataElts, py::arg("data"), py::arg("token") = 1) + .def_static("token_elts", &Elts_t::TokenElts, py::arg("token")) + .def_readwrite("data", &Elts_t::data) + .def_readwrite("token", &Elts_t::token) + .def_readwrite("type", &Elts_t::type) + .def(py::self + py::self) + .def(py::self += py::self) + .def(py::self < py::self) + .def(py::self > py::self); +} + +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 1f35373f3d5105d8815e84949dd4d92c742a296c..cc6f0bf2502027fea467b9db39561769fcebbd2b 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -21,6 +21,7 @@ void init_Random(py::module&); void init_Data(py::module&); void init_DataFormat(py::module&); void init_DataType(py::module&); +void init_Elts(py::module&); void init_Database(py::module&); void init_DataProvider(py::module&); void init_Interpolation(py::module&); @@ -112,6 +113,7 @@ void init_Aidge(py::module& m) { init_Data(m); init_DataFormat(m); init_DataType(m); + init_Elts(m); init_Database(m); init_DataProvider(m); init_Interpolation(m); diff --git a/python_binding/scheduler/pybind_ProdConso.cpp b/python_binding/scheduler/pybind_ProdConso.cpp index abd6d5379178916b5842095d50a1de2155345b6f..547e2258dd353178e5a78e9701cc032e15cafd8b 100644 --- a/python_binding/scheduler/pybind_ProdConso.cpp +++ b/python_binding/scheduler/pybind_ProdConso.cpp @@ -104,6 +104,7 @@ void init_ProdConso(py::module& m){ .def(py::init<const Operator&, bool>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>()) .def_static("default_model", &ProdConso::defaultModel) .def_static("in_place_model", &ProdConso::inPlaceModel) + .def("get_operator", &ProdConso::getOperator) .def("get_nb_required_data", &ProdConso::getNbRequiredData) .def("get_nb_required_protected", &ProdConso::getNbRequiredProtected) .def("get_required_memory", &ProdConso::getRequiredMemory) diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index d4cd7da44d148aef669c893405cff3101b6090b5..1a3a4b6b2f552b05ae33fea0ecff43c4f1ec689f 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -32,6 +32,7 @@ void init_Scheduler(py::module& m){ .def("graph_view", &Scheduler::graphView) .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name")) .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name")) + .def("save_factorized_static_scheduling_diagram", &Scheduler::saveFactorizedStaticSchedulingDiagram, py::arg("file_name"), py::arg("min_repeat") = 2) .def("resetScheduling", &Scheduler::resetScheduling) .def("generate_scheduling", &Scheduler::generateScheduling) .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::EarlyLateSort::Default) diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp index 40b0bda766ab243805349b13e93391c5a60df63a..613393756f7f6b7104118ea97593e3130055ceeb 100644 --- a/src/recipes/ConstantFolding.cpp +++ b/src/recipes/ConstantFolding.cpp @@ -36,6 +36,7 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { for (const auto& node : candidates) { bool foldable = true; auto replaceGraph = std::make_shared<GraphView>(); + size_t i = 0; for (const auto& input : node->inputs()) { if (input.first) { if (input.first->type() != Producer_Op::Type) { @@ -53,6 +54,13 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { replaceGraph->add(input.first, false); } + else if (node->inputCategory(i) != InputCategory::OptionalData + && node->inputCategory(i) != InputCategory::OptionalParam) + { + foldable = false; + break; + } + ++i; } if (foldable) { diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 396e90c091fb62b406ecd336d7e2d29e82c74b33..bbbc3d8076cc1ff5fe25e9f60e4827b94ef0de4f 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -427,6 +427,83 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } } +std::vector<std::pair<std::vector<Aidge::Scheduler::StaticSchedulingElement*>, size_t>> +Aidge::Scheduler::getFactorizedScheduling(const std::vector<StaticSchedulingElement*>& schedule, size_t minRepeat) const +{ + std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> sequences; + size_t offset = 0; + + for (size_t i = 0; i < schedule.size(); ) { + // Find all the possible repetitive sequences starting from this element + std::vector<StaticSchedulingElement*> seq; + std::vector<std::pair<std::vector<StaticSchedulingElement*>, size_t>> longuestSeq; + std::vector<size_t> longuestSeqOffset; + + for (size_t k = i; k < schedule.size(); ++k) { + // For each sequence length, starting from 1... + seq.push_back(new StaticSchedulingElement( + schedule[k]->node, + schedule[k]->early - offset, + schedule[k]->late - offset)); + + size_t start = k + 1; + size_t nbRepeats = 1; + bool repeat = true; + const auto seqOffset = (start < schedule.size()) ? schedule[start]->early - offset - seq[0]->early : 0; + + do { + // Count the number of consecutive sequences (repetitions) + for (size_t r = 0; r < seq.size(); ++r) { + if (start + r >= schedule.size() + || schedule[start + r]->node != seq[r]->node + || schedule[start + r]->early - offset != seq[r]->early + seqOffset * nbRepeats + || schedule[start + r]->late - offset != seq[r]->late + seqOffset * nbRepeats) + { + repeat = false; + break; + } + } + + if (repeat) { + start += seq.size(); + ++nbRepeats; + } + } + while (repeat); + + if (nbRepeats >= minRepeat) { + // If repetitions exist for this sequence length, add it to the list + longuestSeq.push_back(std::make_pair(seq, nbRepeats)); + longuestSeqOffset.push_back(seqOffset); + } + else if (k == i) { + // Ensure that at least the current element is in the list if no + // repetition is found + longuestSeq.push_back(std::make_pair(seq, 1)); + longuestSeqOffset.push_back(0); + } + } + + // Select the one with the best factorization + // i.e. which maximize the product sequence length * number of sequences + size_t maxS = 0; + size_t maxFactorization = 0; + for (size_t s = 0; s < longuestSeq.size(); ++s) { + const auto factor = longuestSeq[s].first.size() * longuestSeq[s].second; + if (factor > maxFactorization) { + maxFactorization = factor; + maxS = s; + } + } + + sequences.push_back(longuestSeq[maxS]); + i += longuestSeq[maxS].first.size() * longuestSeq[maxS].second; + offset += longuestSeqOffset[maxS] * (longuestSeq[maxS].second - 1); + } + + return sequences; +} + void Aidge::Scheduler::resetScheduling() { for (auto node : mGraphView->getNodes()) { node->getOperator()->resetConsummerProducer(); @@ -869,8 +946,65 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) // Mermaid does not allow : character in task title std::replace(name.begin(), name.end(), ':', '_'); - fmt::print(fp.get(), "{} :{}, {}\n", - name, element->early, element->late); + if (element->early == element->late) { + fmt::print(fp.get(), "{} :milestone, {}, {}\n", + name, element->early, element->late); + } + else { + fmt::print(fp.get(), "{} :{}, {}\n", + name, element->early, element->late); + } + } + } + } + + fmt::print(fp.get(), "\n"); +} + +void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat) const { + auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); + + if (!fp) { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Could not create scheduling diagram log file: {}", fileName + ".mmd"); + } + + fmt::print(fp.get(), "gantt\ndateFormat x\naxisFormat %Q\n\n"); + + if (!mStaticSchedule.empty()) { + const std::map<std::shared_ptr<Node>, std::string> namePtrTable + = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + + for (const auto& schedule : mStaticSchedule) { + const auto factorizedSchedule = getFactorizedScheduling(schedule, minRepeat); + + size_t seq = 0; + for (const auto& sequence : factorizedSchedule) { + if (sequence.second > 1) { + fmt::print(fp.get(), "section seq#{} (x{})\n", seq, sequence.second); + } + else { + fmt::print(fp.get(), "section seq#{}\n", seq); + } + + for (const auto& element : sequence.first) { + auto name = namePtrTable.at(element->node); + // Mermaid does not allow : character in task title + std::replace(name.begin(), name.end(), ':', '_'); + std::string tag = ":"; + + if (element->early == element->late) { + tag += "milestone, "; + } + + if (sequence.second > 1) { + tag += "active, "; + } + + fmt::print(fp.get(), "{} {}{}, {}\n", + name, tag, element->early, element->late); + } + ++seq; } } }