Skip to content
Snippets Groups Projects
Commit a1b31cdc authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added conditional execution

parent c3318428
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!332Add selection mechanism in graph
......@@ -127,7 +127,22 @@ public:
virtual ~Scheduler();
public:
void tagConditionalNodes();
/**
* @brief Add schedule.cond attribute to conditional nodes.
* The schedule.cond attribute is a `std::set<std::pair<NodePtr, size_t>>`,
* where the first element is the Select node and the second element, the
* Select input index.
*/
void tagConditionalNodes() const;
/**
* @brief Check if the node condition is valid.
*
* @param node Node to check the condition.
* @return true If the node condition is valid, meaning it has to be executed.
* @return false If the node condition is not valid, meaning it can be skipped.
*/
bool isNodeCondValid(NodePtr node) const;
/**
* @brief Get the static scheduling order of nodes.
......
......@@ -88,6 +88,11 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
// Add the critical node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
if (!isNodeCondValid(node)) {
finished = true;
return;
}
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
......@@ -144,6 +149,11 @@ void Aidge::ParallelScheduler::forward(bool forwardDims, const std::vector<std::
// Add the node to the thread pool queue, to be run ASAP
finished[runnable] = false;
pool.queueJob([node = runnable->node, &finished = finished.at(runnable), &schedulingMutex, this]() {
if (!isNodeCondValid(node)) {
finished = true;
return;
}
const auto tStart = std::chrono::high_resolution_clock::now();
node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
......
......@@ -56,7 +56,7 @@ void Aidge::Scheduler::generateScheduling() {
mStaticSchedule.push_back(schedule);
}
void Aidge::Scheduler::tagConditionalNodes() {
void Aidge::Scheduler::tagConditionalNodes() const {
// Get a list of selectors
std::vector<NodePtr> selectors;
for (const auto& node : mGraphView->getNodes()) {
......@@ -104,6 +104,27 @@ void Aidge::Scheduler::tagConditionalNodes() {
}
}
bool Aidge::Scheduler::isNodeCondValid(NodePtr node) const {
bool skip = false;
if (node->attributes()->hasAttr("schedule.cond")) {
skip = true;
auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& selectNode = cond.first;
const auto selectOp = std::static_pointer_cast<OperatorTensor>(selectNode->getOperator());
std::shared_ptr<Tensor> selectFallback;
const auto& select = selectOp->getInput(0)->refCastFrom(selectFallback, DataType::Int32, "cpu");
const auto selectVal = select.get<int32_t>(0);
skip &= (selectVal != static_cast<int32_t>(cond.second));
}
}
return !skip;
}
std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
// 0) setup useful variables
......@@ -512,10 +533,10 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
const auto& select = cond.first;
if (node == select->input(0).first) {
const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
const auto itElt = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[condNode](const auto& v) { return (v->node == condNode); });
if (it != schedule.end()) {
const std::size_t step = std::distance(schedule.begin(), it);
if (itElt != schedule.end()) {
const std::size_t step = std::distance(schedule.begin(), itElt);
late = std::min(late, schedule[step]->late - 1);
schedule[step]->laterThan.push_back(schedule[elt]);
}
......@@ -526,8 +547,6 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
}
}
// TODO: ADD HERE SCHEDULE COND
schedule[elt]->late = late;
}
}
......
......@@ -59,12 +59,15 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
for (const auto& runnable : staticSchedule) {
Log::debug("run: {}", namePtrTable.at(runnable->node));
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd));
const bool skip = !isNodeCondValid(runnable->node);
Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : "");
if (!skip) {
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd));
}
}
++mStaticScheduleStep;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment