diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index ed9db47b336e9bde86cc761db8524c88dd818379..a164f6c76663e7227533bf5615335e9d228dd15a 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -127,7 +127,22 @@ public: virtual ~Scheduler(); public: - void tagConditionalNodes(); + /** + * @brief Add schedule.cond attribute to conditional nodes. + * The schedule.cond attribute is a `std::set<std::pair<NodePtr, size_t>>`, + * where the first element is the Select node and the second element, the + * Select input index. + */ + void tagConditionalNodes() const; + + /** + * @brief Check if the node condition is valid. + * + * @param node Node to check the condition. + * @return true If the node condition is valid, meaning it has to be executed. + * @return false If the node condition is not valid, meaning it can be skipped. + */ + bool isNodeCondValid(NodePtr node) const; /** * @brief Get the static scheduling order of nodes. diff --git a/src/scheduler/ParallelScheduler.cpp b/src/scheduler/ParallelScheduler.cpp index 2a44dd49f961bdcdf965a33d2ffe91f3ed8ae352..8e53254d71cc84e1d92932d0fe1e0d40ae999232 100644 --- a/src/scheduler/ParallelScheduler.cpp +++ b/src/scheduler/ParallelScheduler.cpp @@ -88,6 +88,11 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the critical node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + if (!isNodeCondValid(node)) { + finished = true; + return; + } + const auto tStart = std::chrono::high_resolution_clock::now(); node->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); @@ -144,6 +149,11 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std:: // Add the node to the thread pool queue, to be run ASAP finished[runnable] = false; pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() { + if (!isNodeCondValid(node)) { + finished = true; + return; + } + const auto tStart = std::chrono::high_resolution_clock::now(); node->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 3f59f2fdc6117caa60f1747e9d824a91edd5845c..bc0b19cfe317ebab0857fc25164d23a103dcdfb7 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -56,7 +56,7 @@ void Aidge::Scheduler::generateScheduling() { mStaticSchedule.push_back(schedule); } -void Aidge::Scheduler::tagConditionalNodes() { +void Aidge::Scheduler::tagConditionalNodes() const { // Get a list of selectors std::vector<NodePtr> selectors; for (const auto& node : mGraphView->getNodes()) { @@ -104,6 +104,27 @@ void Aidge::Scheduler::tagConditionalNodes() { } } +bool Aidge::Scheduler::isNodeCondValid(NodePtr node) const { + bool skip = false; + if (node->attributes()->hasAttr("schedule.cond")) { + skip = true; + + auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond"); + for (const auto& cond : attr) { + const auto& selectNode = cond.first; + const auto selectOp = std::static_pointer_cast<OperatorTensor>(selectNode->getOperator()); + + std::shared_ptr<Tensor> selectFallback; + const auto& select = selectOp->getInput(0)->refCastFrom(selectFallback, DataType::Int32, "cpu"); + const auto selectVal = select.get<int32_t>(0); + + skip &= (selectVal != static_cast<int32_t>(cond.second)); + } + } + + return !skip; +} + std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const { // 0) setup useful variables @@ -512,10 +533,10 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE const auto& select = cond.first; if (node == select->input(0).first) { - const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(), + const auto itElt = std::find_if(schedule.begin() + elt + 1, schedule.end(), [condNode](const auto& v) { return (v->node == condNode); }); - if (it != schedule.end()) { - const std::size_t step = std::distance(schedule.begin(), it); + if (itElt != schedule.end()) { + const std::size_t step = std::distance(schedule.begin(), itElt); late = std::min(late, schedule[step]->late - 1); schedule[step]->laterThan.push_back(schedule[elt]); } @@ -526,8 +547,6 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE } } - // TODO: ADD HERE SCHEDULE COND - schedule[elt]->late = late; } } diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 4e6e91f51878ffce7d910a361c9a6e8fff9cb835..5f6bb6c074df9c9a03b4e79cd95f5bba1f031196 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -59,12 +59,15 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); for (const auto& runnable : staticSchedule) { - Log::debug("run: {}", namePtrTable.at(runnable->node)); - - const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->node->forward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd)); + const bool skip = !isNodeCondValid(runnable->node); + Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : ""); + + if (!skip) { + const auto tStart = std::chrono::high_resolution_clock::now(); + runnable->node->forward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd)); + } } ++mStaticScheduleStep;