Skip to content
Snippets Groups Projects
Commit 170d5931 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working concept of with tagConditionalNodes()

parent 602fe5a5
No related branches found
No related tags found
No related merge requests found
...@@ -127,6 +127,8 @@ public: ...@@ -127,6 +127,8 @@ public:
virtual ~Scheduler(); virtual ~Scheduler();
public: public:
void tagConditionalNodes();
/** /**
* @brief Get the static scheduling order of nodes. * @brief Get the static scheduling order of nodes.
* @param step The step of the static schedule to retrieve (default is 0). * @param step The step of the static schedule to retrieve (default is 0).
......
...@@ -165,7 +165,10 @@ public: ...@@ -165,7 +165,10 @@ public:
else { else {
const auto ns = name.substr(0, dot); const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1); const auto nsName = name.substr(dot + 1);
future_std::any_cast<DynamicAttributes&>(mAttrs.at(ns)).delAttr(nsName); auto it = mAttrs.find(ns);
if (it != mAttrs.end()) {
future_std::any_cast<DynamicAttributes&>(it->second).delAttr(nsName);
}
} }
} }
......
...@@ -56,6 +56,54 @@ void Aidge::Scheduler::generateScheduling() { ...@@ -56,6 +56,54 @@ void Aidge::Scheduler::generateScheduling() {
mStaticSchedule.push_back(schedule); mStaticSchedule.push_back(schedule);
} }
void Aidge::Scheduler::tagConditionalNodes() {
// Get a list of selectors
std::vector<NodePtr> selectors;
for (const auto& node : mGraphView->getNodes()) {
if (node->type() == "Select") {
selectors.push_back(node);
}
node->attributes()->delAttr("schedule.cond");
}
std::function<void(NodePtr, std::set<NodePtr>&)> recInBranch = [&recInBranch](NodePtr node, std::set<NodePtr>& branchNodes) {
bool inBranch = true;
for (const auto& child : node->getChildren()) {
if (branchNodes.find(child) == branchNodes.end()) {
inBranch = false;
break;
}
}
if (inBranch) {
branchNodes.insert(node);
for (const auto& parent : node->getParents()) {
recInBranch(parent, branchNodes);
}
}
};
// For each selector, tag nodes
for (const auto& select : selectors) {
for (size_t branch = 0; branch < select->getParents().size() - 1; ++branch) {
std::set<NodePtr> branchNodes;
branchNodes.insert(select);
recInBranch(select->getParent(branch + 1), branchNodes);
branchNodes.erase(select);
for (const auto& node : branchNodes) {
std::set<std::pair<NodePtr, size_t>> attr;
if (node->attributes()->hasAttr("schedule.cond")) {
attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
}
attr.insert({select, branch});
node->attributes()->setAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond", attr);
}
}
}
}
std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const { std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
// 0) setup useful variables // 0) setup useful variables
...@@ -182,6 +230,22 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera ...@@ -182,6 +230,22 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera
} }
} }
if (consumer->attributes()->hasAttr("schedule.cond")) {
auto attr = consumer->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
AvailableDataStatus status;
if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) >
getNbAvailableData(select, 0, status))
{
isRunnable = false;
break;
}
}
}
if (isRunnable) { if (isRunnable) {
runnableConsumers.insert(consumer); runnableConsumers.insert(consumer);
} }
...@@ -386,6 +450,23 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE ...@@ -386,6 +450,23 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
} }
} }
if (node->attributes()->hasAttr("schedule.cond")) {
auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
const auto& parent = select->input(0).first;
const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
[parent](const auto& v) { return (v->node == parent); });
if (it != schedule.rend()) {
const std::size_t step = std::distance(schedule.begin(), it.base()) - 1;
early = std::max(early, schedule[step]->early + 1);
schedule[step]->earlierThan.push_back(schedule[elt]);
}
}
}
latest = std::max(latest, early); latest = std::max(latest, early);
schedule[elt]->early = early; schedule[elt]->early = early;
} }
...@@ -421,8 +502,32 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE ...@@ -421,8 +502,32 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
late = std::min(late, schedule[step]->late - 1); late = std::min(late, schedule[step]->late - 1);
schedule[step]->laterThan.push_back(schedule[elt]); schedule[step]->laterThan.push_back(schedule[elt]);
} }
if (child->type() == "Select") {
for (const auto& condNode : mGraphView->getNodes()) {
if (condNode->attributes()->hasAttr("schedule.cond")) {
auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
if (node == select->input(0).first) {
const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[condNode](const auto& v) { return (v->node == condNode); });
if (it != schedule.end()) {
const std::size_t step = std::distance(schedule.begin(), it);
late = std::min(late, schedule[step]->late - 1);
schedule[step]->laterThan.push_back(schedule[elt]);
}
}
}
}
}
}
} }
// TODO: ADD HERE SCHEDULE COND
schedule[elt]->late = late; schedule[elt]->late = late;
} }
} }
...@@ -1148,6 +1253,30 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) ...@@ -1148,6 +1253,30 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node)
++inputIdx; ++inputIdx;
} }
if (node->attributes()->hasAttr("schedule.cond")) {
auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
const auto& parent = select->input(0);
if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) >
parent.first->getOperator()->getNbProducedData(parent.second))
{
const auto& parentPrior = getPriorProducersConsumers(parent.first);
if (!parentPrior.isPrior) {
// only happens in case of cyclic graphs
return PriorProducersConsumers(); // not scheduled
}
else {
prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
}
}
}
}
prior.isPrior = true; prior.isPrior = true;
if (node->type() == Producer_Op::Type) { if (node->type() == Producer_Op::Type) {
prior.requiredProducers.insert(node); prior.requiredProducers.insert(node);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment