diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 88aebb3ba64e3745e325e4752a425064a3985646..7af3c62c5d0af33b01e596ecf4c91c35ab3e17b7 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -77,6 +77,9 @@ std::vector<Aidge::Scheduler::StaticSchedulingElement*> Aidge::Scheduler::genera // we should always consume available data first. This is ensured // by setting the consumers list to the output nodes and then recursively // find the dependencies. + // The initial list may contain producer nodes, in which case + // getPriorProducersConsumers() at step 2 will have moved it in + // the requiredProducers list. std::set<std::shared_ptr<Node>> consumers = mGraphView->outputNodes(); std::set<std::shared_ptr<Node>> producers; @@ -737,7 +740,9 @@ Aidge::Scheduler::getPriorProducersConsumers(const std::shared_ptr<Node>& node) } prior.isPrior = true; - if (prior.priorConsumers.empty()) { + if (node->type() == Producer_Op::Type) { + prior.requiredProducers.insert(node); + } else if (prior.priorConsumers.empty()) { prior.priorConsumers.insert(node); } mPriorCache.insert(std::make_pair(node, prior)); diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 3c3026ff09222f9623d886f9c4574bf23667cd9a..ec850d28109a2682bb762c89e814622de6eec3d8 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -26,6 +26,8 @@ #include "aidge/graph/Testing.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Identity.hpp" +#include "aidge/operator/GenericOperator.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" namespace Aidge { @@ -134,4 +136,58 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { fmt::print("nbUnicity = {}/{}\n", nbUnicity, nbTests); } +TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { + + SECTION("Identity Graph") { + auto data1 = Producer({1}, "data1"); + auto identity = Identity("id"); + auto g = std::make_shared<GraphView>("TestGraph"); + data1->addChild(identity); + g->add({data1, identity}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == data1); + REQUIRE(sch[1] == identity); + } + + SECTION("Producer Graph") { + auto data1 = Producer({1}, "data1"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({data1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == data1); + } + + SECTION("Generic producer Graph") { + auto gen1 = GenericOperator("Prod", 0, 0, 1, "gen1"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({gen1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(sch.size() == nodes.size()); + REQUIRE(sch[0] == gen1); + } + + SECTION("No output Graph") { + auto dead1 = GenericOperator("Dead", 1, 0, 0, "dead"); + auto g = std::make_shared<GraphView>("TestGraph"); + g->add({dead1}); + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + const auto sch = scheduler.getStaticScheduling(); + const auto nodes = g->getNodes(); + REQUIRE(nodes.size() == 1); + REQUIRE(sch.size() == 0); + } +} + } // namespace Aidge