Skip to content
Snippets Groups Projects
Commit c57db282 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'fix-scheduler-debug' into 'dev'

[Scheduler] fix internal issue when scheduling producer nodes

See merge request !224
parents c815b549 1b2553b3
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!224[Scheduler] fix internal issue when scheduling producer nodes
Pipeline #57533 passed
......@@ -77,6 +77,9 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera
// we should always consume available data first. This is ensured
// by setting the consumers list to the output nodes and then recursively
// find the dependencies.
// The initial list may contain producer nodes, in which case
// getPriorProducersConsumers() at step 2 will have moved it in
// the requiredProducers list.
std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes();
std::set<std::shared_ptr<Node>> producers;
......@@ -300,19 +303,23 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera
void Aidge::Scheduler::summarizeConsumerState(const std::shared_ptr<Aidge::Node>& consumer, const std::string& nodeName) const {
Log::debug("\t- consumer: {}", fmt::styled(nodeName, fg(fmt::color::orange)));
std::string crLog = "\t\tC/R:\t";
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
consumer->getOperator()->getNbRequiredData(inId));
if (consumer->nbInputs() > 0) {
for (IOIndex_t inId = 0; inId < consumer->nbInputs() - 1; ++inId) {
crLog += fmt::format("{}/{}\n\t\t\t", consumer->getOperator()->getNbConsumedData(inId),
consumer->getOperator()->getNbRequiredData(inId));
}
crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
}
crLog += fmt::format("{}/{}", consumer->getOperator()->getNbConsumedData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1),
consumer->getOperator()->getNbRequiredData(static_cast<IOIndex_t>(consumer->nbInputs()) - 1));
Log::debug("{}", crLog);
std::string pLog = "\t\tP:\t";
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
if (consumer->nbOutputs() > 0) {
for (IOIndex_t outId = 0; outId < consumer->nbOutputs() - 1; ++outId) {
pLog += fmt::format("{}\n\t\t\t", consumer->getOperator()->getNbProducedData(outId));
}
pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
}
pLog += fmt::format("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
Log::debug("{}", pLog);
}
......@@ -733,7 +740,9 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node)
}
prior.isPrior = true;
if (prior.priorConsumers.empty()) {
if (node->type() == Producer_Op::Type) {
prior.requiredProducers.insert(node);
} else if (prior.priorConsumers.empty()) {
prior.priorConsumers.insert(node);
}
mPriorCache.insert(std::make_pair(node, prior));
......
......@@ -26,6 +26,8 @@
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Identity.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
namespace Aidge {
......@@ -134,4 +136,58 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") {
fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests);
}
TEST_CASE("someScheduling", "[Scheduler][someUseCases]") {
SECTION("Identity Graph") {
auto data1 = Producer({1}, "data1");
auto identity = Identity("id");
auto g = std::make_shared<GraphView>("TestGraph");
data1->addChild(identity);
g->add({data1, identity});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(sch.size() == nodes.size());
REQUIRE(sch[0] == data1);
REQUIRE(sch[1] == identity);
}
SECTION("Producer Graph") {
auto data1 = Producer({1}, "data1");
auto g = std::make_shared<GraphView>("TestGraph");
g->add({data1});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(sch.size() == nodes.size());
REQUIRE(sch[0] == data1);
}
SECTION("Generic producer Graph") {
auto gen1 = GenericOperator("Prod", 0, 0, 1, "gen1");
auto g = std::make_shared<GraphView>("TestGraph");
g->add({gen1});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(sch.size() == nodes.size());
REQUIRE(sch[0] == gen1);
}
SECTION("No output Graph") {
auto dead1 = GenericOperator("Dead", 1, 0, 0, "dead");
auto g = std::make_shared<GraphView>("TestGraph");
g->add({dead1});
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
const auto sch = scheduler.getStaticScheduling();
const auto nodes = g->getNodes();
REQUIRE(nodes.size() == 1);
REQUIRE(sch.size() == 0);
}
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment