From 1b2553b3c390a18cf25364ab6111ebf4ffde7619 Mon Sep 17 00:00:00 2001 From: Christophe Guillon <christophe.guillon@inria.fr> Date: Mon, 21 Oct 2024 17:20:10 +0200 Subject: [PATCH] [Scheduler] Fix scheduling of producer node in graph outputs Handle case where producer nodes are also graph outputs, in this case the initial list of consumers may contain producer nodes. They are now scheduled as requiredProducers instead of runnableConsumers. It actually does not change the effective scheduling, though it handles more uniformely producers nodes (outputs of the graph or not). Added some test cases in Test_scheduler.cpp. --- src/scheduler/Scheduler.cpp | 7 +++- unit_tests/scheduler/Test_Scheduler.cpp | 56 +++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 88aebb3ba..7af3c62c5 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -77,6 +77,9 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera // we should always consume available data first. This is ensured // by setting the consumers list to the output nodes and then recursively // find the dependencies. + // The initial list may contain producer nodes, in which case + // getPriorProducersConsumers() at step 2 will have moved it in + // the requiredProducers list. std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes(); std::set<std::shared_ptr<Node>> producers; @@ -737,7 +740,9 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) } prior.isPrior = true; - if (prior.priorConsumers.empty()) { + if (node->type() == Producer_Op::Type) { + prior.requiredProducers.insert(node); + } else if (prior.priorConsumers.empty()) { prior.priorConsumers.insert(node); } mPriorCache.insert(std::make_pair(node, prior)); diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 3c3026ff0..ec850d281 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -26,6 +26,8 @@ #include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Identity.hpp" +#include "aidge/operator/GenericOperator.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" namespace Aidge { @@ -134,4 +136,58 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); } +TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { + + SECTION("Identity Graph") { + auto data1 = Producer({1}, "data1"); + auto identity = Identity("id"); + auto g = std::make_shared<GraphView>("TestGraph"); + data1->addChild(identity); + g->add({data1, identity}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == data1); + REQUIRE(sch[1] == identity); + } + + SECTION("Producer Graph") { + auto data1 = Producer({1}, "data1"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({data1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == data1); + } + + SECTION("Generic producer Graph") { + auto gen1 = GenericOperator("Prod", 0, 0, 1, "gen1"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({gen1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == gen1); + } + + SECTION("No output Graph") { + auto dead1 = GenericOperator("Dead", 1, 0, 0, "dead"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({dead1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(nodes.size() == 1); + REQUIRE(sch.size() == 0); + } +} + } // namespace Aidge -- GitLab