diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index a164f6c76663e7227533bf5615335e9d228dd15a..881d16e05f1fd54ea83d17add57e59e705529d35 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -131,18 +131,19 @@ public: * @brief Add schedule.cond attribute to conditional nodes. * The schedule.cond attribute is a `std::set<std::pair<NodePtr, size_t>>`, * where the first element is the Select node and the second element, the - * Select input index. + * Select input index (starting from 0, ignoring the condition input). */ void tagConditionalNodes() const; /** - * @brief Check if the node condition is valid. + * @brief Check if the conditional node is required (if one of its conditions + * is true). * * @param node Node to check the condition. - * @return true If the node condition is valid, meaning it has to be executed. - * @return false If the node condition is not valid, meaning it can be skipped. + * @return true If any node condition is true, meaning it has to be executed. + * @return false If all node conditions are false, meaning it can be skipped. */ - bool isNodeCondValid(NodePtr node) const; + bool isConditionalNodeRequired(NodePtr node) const; /** * @brief Get the static scheduling order of nodes. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 315844858103cbce91049ec2195ff0a3bd7a9d81..dd17cd34447ce208a4cd0dd00d2b05a8bee1f590 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -100,10 +100,11 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd const auto namePtrTable = getRankedNodesName("{3}"); for (const std::shared_ptr<Node> &node_ptr : mNodes) { + const std::string hasCondition = (node_ptr->attributes()->hasAttr("schedule.cond")) ? " fa:fa-circle-question" : ""; std::string givenName = (node_ptr->name().empty()) - ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" - : "\"" + node_ptr->name() + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\""; + ? "<em>" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + "</em>" + hasCondition + : "\"" + node_ptr->name() + hasCondition + "<br/><sub><em>(" + node_ptr->type() + "#" + namePtrTable.at(node_ptr) + ")</em></sub>\""; if (verbose) { givenName += "<br/><span style='color:white; background-color: purple; float: right'>" + node_ptr->getOperator()->backend() + "</span>"; diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp index 8e53254d71cc84e1d92932d0fe1e0d40ae999232..fb0d45c94438500fd03b0b2443b9162934e0d320 100644 --- a/src/scheduler/ParallelScheduler.cpp +++ b/src/scheduler/ParallelScheduler.cpp @@ -88,7 +88,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the critical node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { - if (!isNodeCondValid(node)) { + if (!isConditionalNodeRequired(node)) { finished = true; return; } @@ -149,7 +149,7 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { - if (!isNodeCondValid(node)) { + if (!isConditionalNodeRequired(node)) { finished = true; return; } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index bc0b19cfe317ebab0857fc25164d23a103dcdfb7..18210a1f8fb0e9f2474e90279ab413668a592e29 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -104,7 +104,7 @@ void Aidge::Scheduler::tagConditionalNodes() const { } } -bool Aidge::Scheduler::isNodeCondValid(NodePtr node) const { +bool Aidge::Scheduler::isConditionalNodeRequired(NodePtr node) const { bool skip = false; if (node->attributes()->hasAttr("schedule.cond")) { skip = true; @@ -471,6 +471,8 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } } + // Node can be run the earliest just after all its conditions are computed. + // A condition act like an additionnal parent. if (node->attributes()->hasAttr("schedule.cond")) { auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); @@ -514,7 +516,14 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } // Node can be run the latest just before its earliest child is run + bool condition = false; // check if node can be a condition for conditional nodes for (const auto& child : node->getChildren()) { + if (child->type() == "Select" && node == child->input(0).first) { + // If the node child is a Select operator, it may be a condition to + // some conditional nodes (if node is the first input of Select). + condition = true; + } + // Find child node earliest scheduled position const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), [child](const auto& v) { return (v->node == child); }); @@ -523,23 +532,28 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } + } - if (child->type() == "Select") { - for (const auto& condNode : mGraphView->getNodes()) { - if (condNode->attributes()->hasAttr("schedule.cond")) { - auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); - - for (const auto& cond : attr) { - const auto& select = cond.first; - - if (node == select->input(0).first) { - const auto itElt = std::find_if(schedule.begin() + elt + 1, schedule.end(), - [condNode](const auto& v) { return (v->node == condNode); }); - if (itElt != schedule.end()) { - const std::size_t step = std::distance(schedule.begin(), itElt); - late = std::min(late, schedule[step]->late - 1); - schedule[step]->laterThan.push_back(schedule[elt]); - } + // When node is a condition to conditional nodes, it acts like a parent + // to them. Therefore, the conditional nodes should be considered as + // childs to this node. + if (condition) { + for (const auto& condNode : mGraphView->getNodes()) { + if (condNode->attributes()->hasAttr("schedule.cond")) { + auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + + for (const auto& cond : attr) { + const auto& select = cond.first; + + // Check if node is a condition to this conditional node + if (node == select->input(0).first) { + // If so, the conditional node act like a child + const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), + [condNode](const auto& v) { return (v->node == condNode); }); + if (it != schedule.end()) { + const std::size_t step = std::distance(schedule.begin(), it); + late = std::min(late, schedule[step]->late - 1); + schedule[step]->laterThan.push_back(schedule[elt]); } } } diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 5f6bb6c074df9c9a03b4e79cd95f5bba1f031196..2b1956d790f960124db5c034fa8b4fb790af1d54 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -59,7 +59,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); for (const auto& runnable : staticSchedule) { - const bool skip = !isNodeCondValid(runnable->node); + const bool skip = !isConditionalNodeRequired(runnable->node); Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : ""); if (!skip) {