Skip to content
Snippets Groups Projects
Commit c3318428 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working concept of with tagConditionalNodes()

parent 2512f855
No related branches found
No related tags found
1 merge request!332Add selection mechanism in graph
......@@ -127,6 +127,8 @@ public:
virtual ~Scheduler();
public:
void tagConditionalNodes();
/**
* @brief Get the static scheduling order of nodes.
* @param step The step of the static schedule to retrieve (default is 0).
......
......@@ -165,7 +165,10 @@ public:
else {
const auto ns = name.substr(0, dot);
const auto nsName = name.substr(dot + 1);
future_std::any_cast<DynamicAttributes&>(mAttrs.at(ns)).delAttr(nsName);
auto it = mAttrs.find(ns);
if (it != mAttrs.end()) {
future_std::any_cast<DynamicAttributes&>(it->second).delAttr(nsName);
}
}
}
......
......@@ -56,6 +56,54 @@ void Aidge::Scheduler::generateScheduling() {
mStaticSchedule.push_back(schedule);
}
void Aidge::Scheduler::tagConditionalNodes() {
// Get a list of selectors
std::vector<NodePtr> selectors;
for (const auto& node : mGraphView->getNodes()) {
if (node->type() == "Select") {
selectors.push_back(node);
}
node->attributes()->delAttr("schedule.cond");
}
std::function<void(NodePtr, std::set<NodePtr>&)> recInBranch = [&recInBranch](NodePtr node, std::set<NodePtr>& branchNodes) {
bool inBranch = true;
for (const auto& child : node->getChildren()) {
if (branchNodes.find(child) == branchNodes.end()) {
inBranch = false;
break;
}
}
if (inBranch) {
branchNodes.insert(node);
for (const auto& parent : node->getParents()) {
recInBranch(parent, branchNodes);
}
}
};
// For each selector, tag nodes
for (const auto& select : selectors) {
for (size_t branch = 0; branch < select->getParents().size() - 1; ++branch) {
std::set<NodePtr> branchNodes;
branchNodes.insert(select);
recInBranch(select->getParent(branch + 1), branchNodes);
branchNodes.erase(select);
for (const auto& node : branchNodes) {
std::set<std::pair<NodePtr, size_t>> attr;
if (node->attributes()->hasAttr("schedule.cond")) {
attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
}
attr.insert({select, branch});
node->attributes()->setAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond", attr);
}
}
}
}
std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::generateBaseScheduling() const {
// 0) setup useful variables
......@@ -182,6 +230,22 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera
}
}
if (consumer->attributes()->hasAttr("schedule.cond")) {
auto attr = consumer->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
AvailableDataStatus status;
if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) >
getNbAvailableData(select, 0, status))
{
isRunnable = false;
break;
}
}
}
if (isRunnable) {
runnableConsumers.insert(consumer);
}
......@@ -386,6 +450,23 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
}
}
if (node->attributes()->hasAttr("schedule.cond")) {
auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
const auto& parent = select->input(0).first;
const auto it = std::find_if(schedule.rend() - elt, schedule.rend(),
[parent](const auto& v) { return (v->node == parent); });
if (it != schedule.rend()) {
const std::size_t step = std::distance(schedule.begin(), it.base()) - 1;
early = std::max(early, schedule[step]->early + 1);
schedule[step]->earlierThan.push_back(schedule[elt]);
}
}
}
latest = std::max(latest, early);
schedule[elt]->early = early;
}
......@@ -421,8 +502,32 @@ void Aidge::Scheduler::generateEarlyLateScheduling(std::vector<StaticSchedulingE
late = std::min(late, schedule[step]->late - 1);
schedule[step]->laterThan.push_back(schedule[elt]);
}
if (child->type() == "Select") {
for (const auto& condNode : mGraphView->getNodes()) {
if (condNode->attributes()->hasAttr("schedule.cond")) {
auto attr = condNode->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
if (node == select->input(0).first) {
const auto it = std::find_if(schedule.begin() + elt + 1, schedule.end(),
[condNode](const auto& v) { return (v->node == condNode); });
if (it != schedule.end()) {
const std::size_t step = std::distance(schedule.begin(), it);
late = std::min(late, schedule[step]->late - 1);
schedule[step]->laterThan.push_back(schedule[elt]);
}
}
}
}
}
}
}
// TODO: ADD HERE SCHEDULE COND
schedule[elt]->late = late;
}
}
......@@ -1148,6 +1253,30 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node)
++inputIdx;
}
if (node->attributes()->hasAttr("schedule.cond")) {
auto attr = node->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond");
for (const auto& cond : attr) {
const auto& select = cond.first;
const auto& parent = select->input(0);
if ((select->getOperator()->getNbConsumedData(0) + select->getOperator()->getNbRequiredData(0)) >
parent.first->getOperator()->getNbProducedData(parent.second))
{
const auto& parentPrior = getPriorProducersConsumers(parent.first);
if (!parentPrior.isPrior) {
// only happens in case of cyclic graphs
return PriorProducersConsumers(); // not scheduled
}
else {
prior.requiredProducers.insert(parentPrior.requiredProducers.cbegin(), parentPrior.requiredProducers.cend());
prior.priorConsumers.insert(parentPrior.priorConsumers.cbegin(), parentPrior.priorConsumers.cend());
}
}
}
}
prior.isPrior = true;
if (node->type() == Producer_Op::Type) {
prior.requiredProducers.insert(node);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment