From a2705ba1fa60b27f23ffd11b53240bafa918d1ad Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 4 Mar 2025 15:39:31 +0100 Subject: [PATCH] Added new scheduling policies --- include/aidge/graph/Node.hpp | 15 ++++ include/aidge/graph/Testing.hpp | 2 + include/aidge/operator/Operator.hpp | 17 ++++ include/aidge/scheduler/Scheduler.hpp | 9 ++- include/aidge/utils/Attributes.hpp | 16 ++++ include/aidge/utils/DynamicAttributes.hpp | 67 +++++++++++++--- src/graph/GraphView.cpp | 12 ++- src/graph/Testing.cpp | 8 ++ src/scheduler/Scheduler.cpp | 95 +++++++++++++++++++++++ 9 files changed, 225 insertions(+), 16 deletions(-) diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index cd2ca38df..b270d1474 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -542,5 +542,20 @@ private: } // namespace Aidge +template <> +struct fmt::formatter<Aidge::Node> : formatter<const char*> { + template <typename FormatContext> + auto format(const Aidge::Node& node, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "n.{}", node.getOperator()); + } +}; + +template <> +struct fmt::formatter<std::shared_ptr<Aidge::Node>> : formatter<const char*> { + template <typename FormatContext> + auto format(const std::shared_ptr<Aidge::Node>& node, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "n.{}", node->getOperator()); + } +}; #endif /* AIDGE_CORE_GRAPH_NODE_H_ */ diff --git a/include/aidge/graph/Testing.hpp b/include/aidge/graph/Testing.hpp index ecacdf662..9ead0b406 100644 --- a/include/aidge/graph/Testing.hpp +++ b/include/aidge/graph/Testing.hpp @@ -58,6 +58,8 @@ std::string nodePtrToType(NodePtr node); std::string nodePtrToName(NodePtr node); std::set<std::string> nodePtrTo(const std::set<NodePtr>& nodes, std::string(*nodeTo)(NodePtr) = nodePtrToType); +std::vector<std::string> nodePtrTo(const std::vector<NodePtr>& nodes, + std::string(*nodeTo)(NodePtr) = nodePtrToType); std::vector<std::pair<std::string, IOIndex_t>> nodePtrTo( const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes, std::string(*nodeTo)(NodePtr) = nodePtrToType); diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 5a12cfea2..3797546b9 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -377,4 +377,21 @@ public: } // namespace Aidge + +template <> +struct fmt::formatter<Aidge::Operator> : formatter<const char*> { + template <typename FormatContext> + auto format(const Aidge::Operator& op, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "op.{}", op.type()); + } +}; + +template <> +struct fmt::formatter<std::shared_ptr<Aidge::Operator>> : formatter<const char*> { + template <typename FormatContext> + auto format(const std::shared_ptr<Aidge::Operator>& op, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "op.{}", op->type()); + } +}; + #endif // AIDGE_CORE_OPERATOR_OPERATOR_H_ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 7c309783d..1c269cbb4 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -93,7 +93,9 @@ public: enum class SchedulingPolicy { Default, AsSoonAsPossible, - AsLateAsPossible + AsLateAsPossible, + ShortestBranchFirst, + LonguestBranchFirst }; /** @@ -135,6 +137,11 @@ public: */ void tagConditionalNodes() const; + /** + * @brief Add schedule.branch attribute to nodes. + */ + void tagForkBranches() const; + /** * @brief Get the static scheduling (after generate scheduling). * @return Vector of StaticSchedulingElement pointers. diff --git a/include/aidge/utils/Attributes.hpp b/include/aidge/utils/Attributes.hpp index e25485fe0..123e36a59 100644 --- a/include/aidge/utils/Attributes.hpp +++ b/include/aidge/utils/Attributes.hpp @@ -12,10 +12,13 @@ #ifndef AIDGE_CORE_UTILS_ATTRIBUTES_H_ #define AIDGE_CORE_UTILS_ATTRIBUTES_H_ +#include <memory> #include <string> #include <set> #include <map> +#include <fmt/format.h> + #include "aidge/utils/future_std/any.hpp" #ifdef PYBIND @@ -84,4 +87,17 @@ public: }; } +template<> +struct fmt::formatter<std::shared_ptr<Aidge::Attributes>> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(const std::shared_ptr<Aidge::Attributes>& attrs, FormatContext& ctx) const { + return fmt::format_to(ctx.out(), "{}", *attrs); + } +}; + #endif /* AIDGE_CORE_UTILS_ATTRIBUTES_H_ */ diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 633ce40d9..be3bc1a97 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -19,6 +19,7 @@ #include <cassert> #include <string> #include <typeindex> +#include <regex> #include "aidge/utils/future_std/any.hpp" #include "aidge/utils/Attributes.hpp" @@ -331,6 +332,7 @@ public: #ifdef PYBIND virtual py::object cast(const future_std::any& attr) const = 0; #endif + virtual std::string str(const future_std::any& attr) const = 0; virtual bool compare(const future_std::any&, const future_std::any&) const = 0; virtual size_t hash(const future_std::any&) const = 0; virtual ~AnyUtils_() = default; @@ -344,6 +346,10 @@ public: } #endif + std::string str(const future_std::any& attr) const override final { + return fmt::format("{}", future_std::any_cast<const T&>(attr)); + } + bool compare(const future_std::any& lhs, const future_std::any& rhs) const override final { #ifdef PYBIND if (lhs.type() == typeid(py::object) && rhs.type() != typeid(py::object)) { @@ -385,6 +391,10 @@ struct DynamicAttributes::AnyUtils<py::object> : public DynamicAttributes::AnyUt return future_std::any_cast<const py::object&>(attr); } + std::string str(const future_std::any& attr) const override { + return py::str(future_std::any_cast<const py::object&>(attr)); + } + bool compare(const future_std::any& lhs, const future_std::any& rhs) const override { return (future_std::any_cast<py::object>(lhs) < future_std::any_cast<py::object>(rhs)); } @@ -423,6 +433,17 @@ namespace std { } }; + // Specialization of std::hash for std::pair<T1, T2> + template <typename T1, typename T2> + struct hash<std::pair<T1, T2>> { + std::size_t operator()(const std::pair<T1, T2>& p) const { + std::size_t seed = 0; + Aidge::hash_combine(seed, std::hash<std::remove_const_t<T1>>()(p.first)); + Aidge::hash_combine(seed, std::hash<T2>()(p.second)); + return seed; + } + }; + // General specialization of std::hash for any container that has iterators (e.g., std::vector, std::list, std::set) template <template <typename...> class Container, typename T, typename... Args> struct hash<Container<T, Args...>> { @@ -430,7 +451,8 @@ namespace std { std::size_t seed = 0; for (const auto& v : iterable) { // Recursively hash the value pointed by the iterator - Aidge::hash_combine(seed, std::hash<T>()(v)); + // Use decltype(v) instead of T to make it work for std::map for example. + Aidge::hash_combine(seed, std::hash<std::remove_const_t<std::remove_reference_t<decltype(v)>>>()(v)); } return seed; } @@ -448,21 +470,42 @@ namespace std { return seed; } }; - - // Specialization of std::hash for std::pair<T1, T2> - template <typename T1, typename T2> - struct hash<std::pair<T1, T2>> { - std::size_t operator()(const std::pair<T1, T2>& p) const { - std::size_t seed = 0; - Aidge::hash_combine(seed, std::hash<T1>()(p.first)); - Aidge::hash_combine(seed, std::hash<T2>()(p.second)); - return seed; - } - }; } namespace future_std { bool operator<(const future_std::any& lhs, const future_std::any& rhs); } + +template<typename T, typename Char> +struct fmt::formatter<T, Char, std::enable_if_t<std::is_convertible<T*, Aidge::Attributes*>::value>> { + template<typename ParseContext> + inline constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { + return ctx.begin(); + } + + template<typename FormatContext> + inline auto format(const Aidge::Attributes& attrs, FormatContext& ctx) const { + std::string attrsStr; + for (const auto& attr : attrs.getAttrs()) { + const auto anyUtilsIt = Aidge::DynamicAttributes::mAnyUtils.find(attr.second.type()); + + if (attr.second.type() == typeid(Aidge::DynamicAttributes)) { + auto subAttrs = anyUtilsIt->second->str(attr.second); + subAttrs = std::regex_replace(subAttrs, std::regex("(^|\n)"), "$1 "); + attrsStr += fmt::format("{}:\n{}\n", attr.first, subAttrs); + } + else { + if (anyUtilsIt != Aidge::DynamicAttributes::mAnyUtils.end()) { + attrsStr += fmt::format("{} = {}\n", attr.first, anyUtilsIt->second->str(attr.second)); + } + else { + attrsStr += fmt::format("{} = ???\n", attr.first); + } + } + } + return fmt::format_to(ctx.out(), "{}", attrsStr); + } +}; + #endif /* AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ */ diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index e193a0af4..b38f92636 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -101,11 +101,17 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd const auto namePtrTable = getRankedNodesName("{3}"); for (const std::shared_ptr<Node> &node_ptr : mNodes) { - const std::string hasCondition = (node_ptr->attributes()->hasAttr("schedule.cond")) ? " fa:fa-circle-question" : ""; + std::string attrs; + // Ignore name attribute (if size == 1) + if (node_ptr->attributes()->getAttrs().size() > 1) { + attrs = fmt::format(" <sup><span title=\"{}\" style=\"cursor: pointer; font-weight: bold; color: blue\">[{}]</span></sup>", + *node_ptr->attributes(), node_ptr->attributes()->getAttrs().size()); + } + std::string givenName = (node_ptr->name().empty()) - ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" + hasCondition - : "\"" + node_ptr->name() + hasCondition + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\""; + ? "\"<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" + attrs + "\"" + : "\"" + node_ptr->name() + attrs + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\""; if (verbose) { givenName += "<br/><span style='color:white; background-color: purple; float: right'>" + node_ptr->getOperator()->backend() + "</span>"; diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp index 774ee8912..1c63c4a0c 100644 --- a/src/graph/Testing.cpp +++ b/src/graph/Testing.cpp @@ -120,6 +120,14 @@ std::set<std::string> Aidge::nodePtrTo(const std::set<NodePtr>& nodes, return nodesStr; } +std::vector<std::string> Aidge::nodePtrTo(const std::vector<NodePtr>& nodes, + std::string(*nodeTo)(NodePtr)) +{ + std::vector<std::string> nodesStr; + std::transform(nodes.begin(), nodes.end(), std::inserter(nodesStr, nodesStr.begin()), nodeTo); + return nodesStr; +} + std::vector<std::pair<std::string, Aidge::IOIndex_t>> Aidge::nodePtrTo( const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes, std::string(*nodeTo)(NodePtr)) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 177975545..90222bdba 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -105,6 +105,56 @@ void Aidge::Scheduler::tagConditionalNodes() const { } } +void Aidge::Scheduler::tagForkBranches() const { + for (const auto& node : mGraphView->getNodes()) { + node->attributes()->delAttr("schedule.branch"); + } + + std::function<void(NodePtr, std::set<NodePtr>&)> recInBranch = [&recInBranch](NodePtr node, std::set<NodePtr>& branchNodes) { + bool inBranch = true; + for (const auto& parent : node->getParents()) { + if (branchNodes.find(parent) == branchNodes.end()) { + inBranch = false; + break; + } + } + + if (inBranch) { + branchNodes.insert(node); + for (const auto& child : node->getChildren()) { + recInBranch(child, branchNodes); + } + } + }; + + for (const auto& node : mGraphView->getNodes()) { + if (node->getChildren().size() > 1) { + // In more than 1 child, it is a fork branch + size_t branch = 0; + for (auto childs : node->getOrderedChildren()) { + for (auto child : childs) { + std::set<NodePtr> branchNodes; + branchNodes.insert(node); + recInBranch(child, branchNodes); + branchNodes.erase(node); + + for (const auto& branchNode : branchNodes) { + std::map<NodePtr, size_t> attr; + if (branchNode->attributes()->hasAttr("schedule.branch")) { + attr = branchNode->attributes()->getAttr<std::map<NodePtr, size_t>>("schedule.branch"); + } + + attr.insert({node, branch}); + branchNode->attributes()->setAttr<std::map<NodePtr, size_t>>("schedule.branch", attr); + } + + ++branch; + } + } + } + } +} + bool Aidge::Scheduler::isConditionalNodeRequired(NodePtr node) const { bool skip = false; if (node->attributes()->hasAttr("schedule.cond")) { @@ -1185,6 +1235,51 @@ std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getSequentialStaticS std::stable_sort(staticSchedule.begin(), staticSchedule.end(), [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early > rhs->early)); }); } + else if (policy == SchedulingPolicy::ShortestBranchFirst || policy == SchedulingPolicy::LonguestBranchFirst) { + for (std::size_t elt = 0; elt < staticSchedule.size(); ++elt) { + const auto node = staticSchedule[elt]->node; + + if (node->getChildren().size() > 1) { + // The node is a fork: isolate branches + std::vector<std::deque<StaticSchedulingElement*>> branches(node->getChildren().size()); + + for (std::size_t branchElt = 0; branchElt < staticSchedule.size(); ++branchElt) { + const auto branchNode = staticSchedule[branchElt]->node; + if (branchNode->attributes()->hasAttr("schedule.branch")) { + auto attr = branchNode->attributes()->getAttr<std::map<NodePtr, size_t>>("schedule.branch"); + const auto it = attr.find(node); + if (it != attr.end()) { + branches[it->second].push_back(staticSchedule[branchElt]); + } + else { + // The next scheduled element is not part of one of the originally forked branch + break; + } + } + } + + // Sort branches + if (policy == SchedulingPolicy::ShortestBranchFirst) { + std::stable_sort(branches.begin(), branches.end(), + [](const auto& lhs, const auto& rhs) { return lhs.size() < rhs.size(); }); + } + else { + std::stable_sort(branches.begin(), branches.end(), + [](const auto& lhs, const auto& rhs) { return lhs.size() > rhs.size(); }); + } + + // Flatten the schedule + std::size_t offset = elt + 1; + + for (std::size_t branch = 0; branch < branches.size(); ++branch) { + std::copy(branches[branch].begin(), branches[branch].end(), staticSchedule.begin() + offset); + offset += branches[branch].size(); + } + + // Move to next element (to properly handle branches in branches) + } + } + } std::vector<std::shared_ptr<Node>> schedule; std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; }); -- GitLab