From c3318428b1855e94abdd703e371bbffe42368bbe Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 16 Feb 2025 23:46:37 +0100
Subject: [PATCH] Working concept of with tagConditionalNodes()

---
 include/aidge/scheduler/Scheduler.hpp     |   2 +
 include/aidge/utils/DynamicAttributes.hpp |   5 +-
 src/scheduler/Scheduler.cpp               | 129 ++++++++++++++++++++++
 3 files changed, 135 insertions(+), 1 deletion(-)

diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp
index dfdc270fa..ed9db47b3 100644
--- a/include/aidge/scheduler/Scheduler.hpp
+++ b/include/aidge/scheduler/Scheduler.hpp
@@ -127,6 +127,8 @@ public:
     virtual ~Scheduler();
 
 public:
+    void tagConditionalNodes();
+
     /**
      * @brief Get the static scheduling order of nodes.
      * @param step The step of the static schedule to retrieve (default is 0).
diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp
index 6ac76c138..633ce40d9 100644
--- a/include/aidge/utils/DynamicAttributes.hpp
+++ b/include/aidge/utils/DynamicAttributes.hpp
@@ -165,7 +165,10 @@ public:
         else {
             const auto ns = name.substr(0, dot);
             const auto nsName = name.substr(dot + 1);
-            future_std::any_cast<DynamicAttributes&>(mAttrs.at(ns)).delAttr(nsName);
+            auto it = mAttrs.find(ns);
+            if (it != mAttrs.end()) {
+                future_std::any_cast<DynamicAttributes&>(it->second).delAttr(nsName);
+            }
         }
     }
 
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index fabdc7ad2..3f59f2fdc 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -56,6 +56,54 @@ void Aidge::Scheduler::generateScheduling() {
     mStaticSchedule.push_back(schedule);
 }
 
+void Aidge::Scheduler::tagConditionalNodes() {
+    // Get a list of selectors
+    std::vector<NodePtr> selectors;
+    for (const auto& node : mGraphView->getNodes()) {
+        if (node->type() == "Select") {
+            selectors.push_back(node);
+        }
+        node->attributes()->delAttr("schedule.cond");
+    }
+
+    std::function<void(NodePtr, std::set<NodePtr>&)> recInBranch = [&recInBranch](NodePtr node, std::set<NodePtr>& branchNodes) {
+        bool inBranch = true;
+        for (const auto& child : node->getChildren()) {
+            if (branchNodes.find(child) == branchNodes.end()) {
+                inBranch = false;
+                break;
+            }
+        }
+
+        if (inBranch) {
+            branchNodes.insert(node);
+            for (const auto& parent : node->getParents()) {
+                recInBranch(parent, branchNodes);
+            }
+        }
+    };
+
+    // For each selector, tag nodes
+    for (const auto& select : selectors) {
+        for (size_t branch = 0; branch < select->getParents().size() - 1; ++branch) {
+            std::set<NodePtr> branchNodes;
+            branchNodes.insert(select);
+            recInBranch(select->getParent(branch + 1), branchNodes);
+            branchNodes.erase(select);
+
+            for (const auto& node : branchNodes) {
+                std::set<std::pair<NodePtr, size_t>> attr;
+                if (node->attributes()->hasAttr("schedule.cond")) {
+                    attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
+                }
+
+                attr.insert({select, branch});
+                node->attributes()->setAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond", attr);
+            }
+        }
+    }
+}
+
 std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
 
     // 0) setup useful variables
@@ -182,6 +230,22 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera
                 }
             }
 
+            if (consumer->attributes()->hasAttr("schedule.cond")) {
+                auto attr = consumer->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
+        
+                for (const auto& cond : attr) {
+                    const auto& select = cond.first;
+                    AvailableDataStatus status;
+
+                    if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) >
+                        getNbAvailableData(select, 0, status))
+                    {
+                        isRunnable = false;
+                        break;
+                    }
+                }
+            }
+
             if (isRunnable) {
                 runnableConsumers.insert(consumer);
             }
@@ -386,6 +450,23 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
             }
         }
 
+        if (node->attributes()->hasAttr("schedule.cond")) {
+            auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
+    
+            for (const auto& cond : attr) {
+                const auto& select = cond.first;
+                const auto& parent = select->input(0).first;
+        
+                const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
+                    [parent](const auto& v) { return (v->node == parent); });
+                if (it != schedule.rend()) {
+                    const std::size_t step = std::distance(schedule.begin(), it.base()) - 1;
+                    early = std::max(early, schedule[step]->early + 1);
+                    schedule[step]->earlierThan.push_back(schedule[elt]);
+                }
+            }
+        }
+
         latest = std::max(latest, early);
         schedule[elt]->early = early;
     }
@@ -421,8 +502,32 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
                 late = std::min(late, schedule[step]->late - 1);
                 schedule[step]->laterThan.push_back(schedule[elt]);
             }
+
+            if (child->type() == "Select") {
+                for (const auto& condNode : mGraphView->getNodes()) {
+                    if (condNode->attributes()->hasAttr("schedule.cond")) {
+                        auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
+
+                        for (const auto& cond : attr) {
+                            const auto& select = cond.first;
+
+                            if (node == select->input(0).first) {
+                                const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
+                                    [condNode](const auto& v) { return (v->node == condNode); });
+                                if (it != schedule.end()) {
+                                    const std::size_t step = std::distance(schedule.begin(), it);
+                                    late = std::min(late, schedule[step]->late - 1);
+                                    schedule[step]->laterThan.push_back(schedule[elt]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
         }
 
+        // TODO: ADD HERE SCHEDULE COND
+
         schedule[elt]->late = late;
     }
 }
@@ -1148,6 +1253,30 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node)
         ++inputIdx;
     }
 
+    if (node->attributes()->hasAttr("schedule.cond")) {
+        auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
+
+        for (const auto& cond : attr) {
+            const auto& select = cond.first;
+            const auto& parent = select->input(0);
+
+            if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) >
+                parent.first->getOperator()->getNbProducedData(parent.second))
+            {
+                const auto& parentPrior = getPriorProducersConsumers(parent.first);
+
+                if (!parentPrior.isPrior) {
+                    // only happens in case of cyclic graphs
+                    return PriorProducersConsumers(); // not scheduled
+                }
+                else {
+                    prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
+                    prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
+                }
+            }
+        }
+    }
+
     prior.isPrior = true;
     if (node->type() == Producer_Op::Type) {
         prior.requiredProducers.insert(node);
-- 
GitLab