From a2705ba1fa60b27f23ffd11b53240bafa918d1ad Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 4 Mar 2025 15:39:31 +0100
Subject: [PATCH] Added new scheduling policies

---
 include/aidge/graph/Node.hpp              | 15 ++++
 include/aidge/graph/Testing.hpp           |  2 +
 include/aidge/operator/Operator.hpp       | 17 ++++
 include/aidge/scheduler/Scheduler.hpp     |  9 ++-
 include/aidge/utils/Attributes.hpp        | 16 ++++
 include/aidge/utils/DynamicAttributes.hpp | 67 +++++++++++++---
 src/graph/GraphView.cpp                   | 12 ++-
 src/graph/Testing.cpp                     |  8 ++
 src/scheduler/Scheduler.cpp               | 95 +++++++++++++++++++++++
 9 files changed, 225 insertions(+), 16 deletions(-)

diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp
index cd2ca38df..b270d1474 100644
--- a/include/aidge/graph/Node.hpp
+++ b/include/aidge/graph/Node.hpp
@@ -542,5 +542,20 @@ private:
 
 } // namespace Aidge
 
+template <>
+struct fmt::formatter<Aidge::Node> : formatter<const char*> {
+    template <typename FormatContext>
+    auto format(const Aidge::Node& node, FormatContext& ctx) const {
+        return fmt::format_to(ctx.out(), "n.{}", node.getOperator());
+    }
+};
+
+template <>
+struct fmt::formatter<std::shared_ptr<Aidge::Node>> : formatter<const char*> {
+    template <typename FormatContext>
+    auto format(const std::shared_ptr<Aidge::Node>& node, FormatContext& ctx) const {
+        return fmt::format_to(ctx.out(), "n.{}", node->getOperator());
+    }
+};
 
 #endif /* AIDGE_CORE_GRAPH_NODE_H_ */
diff --git a/include/aidge/graph/Testing.hpp b/include/aidge/graph/Testing.hpp
index ecacdf662..9ead0b406 100644
--- a/include/aidge/graph/Testing.hpp
+++ b/include/aidge/graph/Testing.hpp
@@ -58,6 +58,8 @@ std::string nodePtrToType(NodePtr node);
 std::string nodePtrToName(NodePtr node);
 std::set<std::string> nodePtrTo(const std::set<NodePtr>& nodes,
     std::string(*nodeTo)(NodePtr) = nodePtrToType);
+std::vector<std::string> nodePtrTo(const std::vector<NodePtr>& nodes,
+    std::string(*nodeTo)(NodePtr) = nodePtrToType);
 std::vector<std::pair<std::string, IOIndex_t>> nodePtrTo(
     const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes,
     std::string(*nodeTo)(NodePtr) = nodePtrToType);
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 5a12cfea2..3797546b9 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -377,4 +377,21 @@ public:
 
 }  // namespace Aidge
 
+
+template <>
+struct fmt::formatter<Aidge::Operator> : formatter<const char*> {
+    template <typename FormatContext>
+    auto format(const Aidge::Operator& op, FormatContext& ctx) const {
+        return fmt::format_to(ctx.out(), "op.{}", op.type());
+    }
+};
+
+template <>
+struct fmt::formatter<std::shared_ptr<Aidge::Operator>> : formatter<const char*> {
+    template <typename FormatContext>
+    auto format(const std::shared_ptr<Aidge::Operator>& op, FormatContext& ctx) const {
+        return fmt::format_to(ctx.out(), "op.{}", op->type());
+    }
+};
+
 #endif  // AIDGE_CORE_OPERATOR_OPERATOR_H_
diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index 7c309783d..1c269cbb4 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -93,7 +93,9 @@ public:
     enum class SchedulingPolicy {
         Default,
         AsSoonAsPossible,
-        AsLateAsPossible
+        AsLateAsPossible,
+        ShortestBranchFirst,
+        LonguestBranchFirst
     };
 
     /**
@@ -135,6 +137,11 @@ public:
      */
     void tagConditionalNodes() const;
 
+    /**
+     * @brief Add schedule.branch attribute to nodes.
+     */
+    void tagForkBranches() const;
+
     /**
      * @brief Get the static scheduling (after generate scheduling).
      * @return Vector of StaticSchedulingElement pointers.
diff --git a/include/aidge/utils/Attributes.hpp b/include/aidge/utils/Attributes.hpp
index e25485fe0..123e36a59 100644
--- a/include/aidge/utils/Attributes.hpp
+++ b/include/aidge/utils/Attributes.hpp
@@ -12,10 +12,13 @@
 #ifndef AIDGE_CORE_UTILS_ATTRIBUTES_H_
 #define AIDGE_CORE_UTILS_ATTRIBUTES_H_
 
+#include <memory>
 #include <string>
 #include <set>
 #include <map>
 
+#include <fmt/format.h>
+
 #include "aidge/utils/future_std/any.hpp"
 
 #ifdef PYBIND
@@ -84,4 +87,17 @@ public:
 };
 }
 
+template<>
+struct fmt::formatter<std::shared_ptr<Aidge::Attributes>> {
+    template<typename ParseContext>
+    inline constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) {
+        return ctx.begin();
+    }
+
+    template<typename FormatContext>
+    inline auto format(const std::shared_ptr<Aidge::Attributes>& attrs, FormatContext& ctx) const {
+        return fmt::format_to(ctx.out(), "{}", *attrs);
+    }
+};
+
 #endif /* AIDGE_CORE_UTILS_ATTRIBUTES_H_ */
diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp
index 633ce40d9..be3bc1a97 100644
--- a/include/aidge/utils/DynamicAttributes.hpp
+++ b/include/aidge/utils/DynamicAttributes.hpp
@@ -19,6 +19,7 @@
 #include <cassert>
 #include <string>
 #include <typeindex>
+#include <regex>
 
 #include "aidge/utils/future_std/any.hpp"
 #include "aidge/utils/Attributes.hpp"
@@ -331,6 +332,7 @@ public:
 #ifdef PYBIND
         virtual py::object cast(const future_std::any& attr) const = 0;
 #endif
+        virtual std::string str(const future_std::any& attr) const = 0;
         virtual bool compare(const future_std::any&, const future_std::any&) const = 0;
         virtual size_t hash(const future_std::any&) const = 0;
         virtual ~AnyUtils_() = default;
@@ -344,6 +346,10 @@ public:
         }
 #endif
 
+        std::string str(const future_std::any& attr) const override final {
+            return fmt::format("{}", future_std::any_cast<const T&>(attr));
+        }
+
         bool compare(const future_std::any& lhs, const future_std::any& rhs) const override final {
 #ifdef PYBIND
             if (lhs.type() == typeid(py::object) && rhs.type() != typeid(py::object)) {
@@ -385,6 +391,10 @@ struct DynamicAttributes::AnyUtils<py::object> : public DynamicAttributes::AnyUt
         return future_std::any_cast<const py::object&>(attr);
     }
 
+    std::string str(const future_std::any& attr) const override {
+        return py::str(future_std::any_cast<const py::object&>(attr));
+    }
+
     bool compare(const future_std::any& lhs, const future_std::any& rhs) const override {
         return (future_std::any_cast<py::object>(lhs) < future_std::any_cast<py::object>(rhs));
     }
@@ -423,6 +433,17 @@ namespace std {
         }
     };
 
+    // Specialization of std::hash for std::pair<T1, T2>
+    template <typename T1, typename T2>
+    struct hash<std::pair<T1, T2>> {
+        std::size_t operator()(const std::pair<T1, T2>& p) const {
+            std::size_t seed = 0;
+            Aidge::hash_combine(seed, std::hash<std::remove_const_t<T1>>()(p.first));
+            Aidge::hash_combine(seed, std::hash<T2>()(p.second));
+            return seed;
+        }
+    };
+
     // General specialization of std::hash for any container that has iterators (e.g., std::vector, std::list, std::set)
     template <template <typename...> class Container, typename T, typename... Args>
     struct hash<Container<T, Args...>> {
@@ -430,7 +451,8 @@ namespace std {
             std::size_t seed = 0;
             for (const auto& v : iterable) {
                 // Recursively hash the value pointed by the iterator
-                Aidge::hash_combine(seed, std::hash<T>()(v));
+                // Use decltype(v) instead of T to make it work for std::map for example.
+                Aidge::hash_combine(seed, std::hash<std::remove_const_t<std::remove_reference_t<decltype(v)>>>()(v));
             }
             return seed;
         }
@@ -448,21 +470,42 @@ namespace std {
             return seed;
         }
     };
-
-    // Specialization of std::hash for std::pair<T1, T2>
-    template <typename T1, typename T2>
-    struct hash<std::pair<T1, T2>> {
-        std::size_t operator()(const std::pair<T1, T2>& p) const {
-            std::size_t seed = 0;
-            Aidge::hash_combine(seed, std::hash<T1>()(p.first));
-            Aidge::hash_combine(seed, std::hash<T2>()(p.second));
-            return seed;
-        }
-    };
 }
 
 namespace future_std {
 bool operator<(const future_std::any& lhs, const future_std::any& rhs);
 }
 
+
+template<typename T, typename Char>
+struct fmt::formatter<T, Char,  std::enable_if_t<std::is_convertible<T*, Aidge::Attributes*>::value>> {
+    template<typename ParseContext>
+    inline constexpr auto parse(ParseContext& ctx) -> decltype(ctx.begin()) {
+        return ctx.begin();
+    }
+
+    template<typename FormatContext>
+    inline auto format(const Aidge::Attributes& attrs, FormatContext& ctx) const {
+        std::string attrsStr;
+        for (const auto& attr : attrs.getAttrs()) {
+            const auto anyUtilsIt = Aidge::DynamicAttributes::mAnyUtils.find(attr.second.type());
+
+            if (attr.second.type() == typeid(Aidge::DynamicAttributes)) {
+                auto subAttrs = anyUtilsIt->second->str(attr.second);
+                subAttrs = std::regex_replace(subAttrs, std::regex("(^|\n)"), "$1  "); 
+                attrsStr += fmt::format("{}:\n{}\n", attr.first, subAttrs);
+            }
+            else {
+                if (anyUtilsIt != Aidge::DynamicAttributes::mAnyUtils.end()) {
+                    attrsStr += fmt::format("{} = {}\n", attr.first, anyUtilsIt->second->str(attr.second));
+                }
+                else {
+                    attrsStr += fmt::format("{} = ???\n", attr.first);
+                }
+            }
+        }
+        return fmt::format_to(ctx.out(), "{}", attrsStr);
+    }
+};
+
 #endif /* AIDGE_CORE_UTILS_DYNAMICATTRIBUTES_H_ */
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index e193a0af4..b38f92636 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -101,11 +101,17 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd
     const auto namePtrTable = getRankedNodesName("{3}");
 
     for (const std::shared_ptr<Node> &node_ptr : mNodes) {
-        const std::string hasCondition = (node_ptr->attributes()->hasAttr("schedule.cond")) ? " fa:fa-circle-question" : "";
+        std::string attrs;
+        // Ignore name attribute (if size == 1)
+        if (node_ptr->attributes()->getAttrs().size() > 1) {
+          attrs = fmt::format("&nbsp;<sup><span title=\"{}\" style=\"cursor: pointer; font-weight: bold; color: blue\">[{}]</span></sup>",
+            *node_ptr->attributes(), node_ptr->attributes()->getAttrs().size());
+        }
+
         std::string givenName =
             (node_ptr->name().empty())
-                ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" + hasCondition
-                : "\"" + node_ptr->name() + hasCondition + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\"";
+                ? "\"<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" + attrs + "\""
+                : "\"" + node_ptr->name() + attrs + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\"";
 
         if (verbose) {
           givenName += "<br/><span style='color:white; background-color: purple; float: right'>" + node_ptr->getOperator()->backend() + "</span>";
diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp
index 774ee8912..1c63c4a0c 100644
--- a/src/graph/Testing.cpp
+++ b/src/graph/Testing.cpp
@@ -120,6 +120,14 @@ std::set<std::string> Aidge::nodePtrTo(const std::set<NodePtr>& nodes,
     return nodesStr;
 }
 
+std::vector<std::string> Aidge::nodePtrTo(const std::vector<NodePtr>& nodes,
+    std::string(*nodeTo)(NodePtr))
+{
+    std::vector<std::string> nodesStr;
+    std::transform(nodes.begin(), nodes.end(), std::inserter(nodesStr, nodesStr.begin()), nodeTo);
+    return nodesStr;
+}
+
 std::vector<std::pair<std::string, Aidge::IOIndex_t>> Aidge::nodePtrTo(
     const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes,
     std::string(*nodeTo)(NodePtr))
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 177975545..90222bdba 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -105,6 +105,56 @@ void Aidge::Scheduler::tagConditionalNodes() const {
     }
 }
 
+void Aidge::Scheduler::tagForkBranches() const {
+    for (const auto& node : mGraphView->getNodes()) {
+        node->attributes()->delAttr("schedule.branch");
+    }
+
+    std::function<void(NodePtr, std::set<NodePtr>&)> recInBranch = [&recInBranch](NodePtr node, std::set<NodePtr>& branchNodes) {
+        bool inBranch = true;
+        for (const auto& parent : node->getParents()) {
+            if (branchNodes.find(parent) == branchNodes.end()) {
+                inBranch = false;
+                break;
+            }
+        }
+
+        if (inBranch) {
+            branchNodes.insert(node);
+            for (const auto& child : node->getChildren()) {
+                recInBranch(child, branchNodes);
+            }
+        }
+    };
+
+    for (const auto& node : mGraphView->getNodes()) {
+        if (node->getChildren().size() > 1) {
+            // In more than 1 child, it is a fork branch
+            size_t branch = 0;
+            for (auto childs : node->getOrderedChildren()) {
+                for (auto child : childs) {
+                    std::set<NodePtr> branchNodes;
+                    branchNodes.insert(node);
+                    recInBranch(child, branchNodes);
+                    branchNodes.erase(node);
+
+                    for (const auto& branchNode : branchNodes) {
+                        std::map<NodePtr, size_t> attr;
+                        if (branchNode->attributes()->hasAttr("schedule.branch")) {
+                            attr = branchNode->attributes()->getAttr<std::map<NodePtr, size_t>>("schedule.branch");
+                        }
+
+                        attr.insert({node, branch});
+                        branchNode->attributes()->setAttr<std::map<NodePtr, size_t>>("schedule.branch", attr);
+                    }
+
+                    ++branch;
+                }
+            }
+        }
+    }
+}
+
 bool Aidge::Scheduler::isConditionalNodeRequired(NodePtr node) const {
     bool skip = false;
     if (node->attributes()->hasAttr("schedule.cond")) {
@@ -1185,6 +1235,51 @@ std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getSequentialStaticS
         std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
             [](const auto& lhs, const auto& rhs) { return ((lhs->late < rhs->late) || (lhs->late == rhs->late && lhs->early > rhs->early)); });
     }
+    else if (policy == SchedulingPolicy::ShortestBranchFirst || policy == SchedulingPolicy::LonguestBranchFirst) {
+        for (std::size_t elt = 0; elt < staticSchedule.size(); ++elt) {
+            const auto node = staticSchedule[elt]->node;
+
+            if (node->getChildren().size() > 1) {
+                // The node is a fork: isolate branches
+                std::vector<std::deque<StaticSchedulingElement*>> branches(node->getChildren().size());
+
+                for (std::size_t branchElt = 0; branchElt < staticSchedule.size(); ++branchElt) {
+                    const auto branchNode = staticSchedule[branchElt]->node;
+                    if (branchNode->attributes()->hasAttr("schedule.branch")) {
+                        auto attr = branchNode->attributes()->getAttr<std::map<NodePtr, size_t>>("schedule.branch");
+                        const auto it = attr.find(node);
+                        if (it != attr.end()) {
+                            branches[it->second].push_back(staticSchedule[branchElt]);
+                        }
+                        else {
+                            // The next scheduled element is not part of one of the originally forked branch
+                            break;
+                        }
+                    }
+                }
+
+                // Sort branches
+                if (policy == SchedulingPolicy::ShortestBranchFirst) {
+                    std::stable_sort(branches.begin(), branches.end(), 
+                        [](const auto& lhs, const auto& rhs) { return lhs.size() < rhs.size(); });
+                }
+                else {
+                    std::stable_sort(branches.begin(), branches.end(), 
+                        [](const auto& lhs, const auto& rhs) { return lhs.size() > rhs.size(); });
+                }
+
+                // Flatten the schedule
+                std::size_t offset = elt + 1;
+
+                for (std::size_t branch = 0; branch < branches.size(); ++branch) {
+                    std::copy(branches[branch].begin(), branches[branch].end(), staticSchedule.begin() + offset);
+                    offset += branches[branch].size();
+                }
+
+                // Move to next element (to properly handle branches in branches)
+            }
+        }
+    }
 
     std::vector<std::shared_ptr<Node>> schedule;
     std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
-- 
GitLab