From c3318428b1855e94abdd703e371bbffe42368bbe Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 16 Feb 2025 23:46:37 +0100 Subject: [PATCH] Working concept of with tagConditionalNodes() --- include/aidge/scheduler/Scheduler.hpp | 2 + include/aidge/utils/DynamicAttributes.hpp | 5 +- src/scheduler/Scheduler.cpp | 129 ++++++++++++++++++++++ 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index dfdc270fa..ed9db47b3 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -127,6 +127,8 @@ public: virtual ~Scheduler(); public: + void tagConditionalNodes(); + /** * @brief Get the static scheduling order of nodes. * @param step The step of the static schedule to retrieve (default is 0). diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 6ac76c138..633ce40d9 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -165,7 +165,10 @@ public: else { const auto ns = name.substr(0, dot); const auto nsName = name.substr(dot + 1); - future_std::any_cast<DynamicAttributes&>(mAttrs.at(ns)).delAttr(nsName); + auto it = mAttrs.find(ns); + if (it != mAttrs.end()) { + future_std::any_cast<DynamicAttributes&>(it->second).delAttr(nsName); + } } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index fabdc7ad2..3f59f2fdc 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -56,6 +56,54 @@ void Aidge::Scheduler::generateScheduling() { mStaticSchedule.push_back(schedule); } +void Aidge::Scheduler::tagConditionalNodes() { + // Get a list of selectors + std::vector<NodePtr> selectors; + for (const auto& node : mGraphView->getNodes()) { + if (node->type() == "Select") { + selectors.push_back(node); + } + node->attributes()->delAttr("schedule.cond"); + } + + std::function<void(NodePtr, std::set<NodePtr>&)> recInBranch = [&recInBranch](NodePtr node, std::set<NodePtr>& branchNodes) { + bool inBranch = true; + for (const auto& child : node->getChildren()) { + if (branchNodes.find(child) == branchNodes.end()) { + inBranch = false; + break; + } + } + + if (inBranch) { + branchNodes.insert(node); + for (const auto& parent : node->getParents()) { + recInBranch(parent, branchNodes); + } + } + }; + + // For each selector, tag nodes + for (const auto& select : selectors) { + for (size_t branch = 0; branch < select->getParents().size() - 1; ++branch) { + std::set<NodePtr> branchNodes; + branchNodes.insert(select); + recInBranch(select->getParent(branch + 1), branchNodes); + branchNodes.erase(select); + + for (const auto& node : branchNodes) { + std::set<std::pair<NodePtr, size_t>> attr; + if (node->attributes()->hasAttr("schedule.cond")) { + attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + } + + attr.insert({select, branch}); + node->attributes()->setAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond", attr); + } + } + } +} + std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const { // 0) setup useful variables @@ -182,6 +230,22 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera } } + if (consumer->attributes()->hasAttr("schedule.cond")) { + auto attr = consumer->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + AvailableDataStatus status; + + if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) > + getNbAvailableData(select, 0, status)) + { + isRunnable = false; + break; + } + } + } + if (isRunnable) { runnableConsumers.insert(consumer); } @@ -386,6 +450,23 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } } + if (node->attributes()->hasAttr("schedule.cond")) { + auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + const auto& parent = select->input(0).first; + + const auto it = std::find_if(schedule.rend() - elt, schedule.rend(), + [parent](const auto& v) { return (v->node == parent); }); + if (it != schedule.rend()) { + const std::size_t step = std::distance(schedule.begin(), it.base()) - 1; + early = std::max(early, schedule[step]->early + 1); + schedule[step]->earlierThan.push_back(schedule[elt]); + } + } + } + latest = std::max(latest, early); schedule[elt]->early = early; } @@ -421,8 +502,32 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } + + if (child->type() == "Select") { + for (const auto& condNode : mGraphView->getNodes()) { + if (condNode->attributes()->hasAttr("schedule.cond")) { + auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + + if (node == select->input(0).first) { + const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), + [condNode](const auto& v) { return (v->node == condNode); }); + if (it != schedule.end()) { + const std::size_t step = std::distance(schedule.begin(), it); + late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); + } + } + } + } + } + } } + // TODO: ADD HERE SCHEDULE COND + schedule[elt]->late = late; } } @@ -1148,6 +1253,30 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) ++inputIdx; } + if (node->attributes()->hasAttr("schedule.cond")) { + auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + const auto& parent = select->input(0); + + if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) > + parent.first->getOperator()->getNbProducedData(parent.second)) + { + const auto& parentPrior = getPriorProducersConsumers(parent.first); + + if (!parentPrior.isPrior) { + // only happens in case of cyclic graphs + return PriorProducersConsumers(); // not scheduled + } + else { + prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend()); + prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend()); + } + } + } + } + prior.isPrior = true; if (node->type() == Producer_Op::Type) { prior.requiredProducers.insert(node); -- GitLab