From d34c46219694ac7084073c097a2a6de04b223af9 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 16 Mar 2025 11:30:52 +0100 Subject: [PATCH] Revert "Removed unrelated change" This reverts commit ea9a0a70e58900bbc54aeded143e7a37f62bcf92. --- unit_tests/scheduler/Test_Scheduler.cpp | 53 +++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 54e57ec4..be87e8ac 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -17,6 +17,7 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Memorize.hpp" #include "aidge/operator/Pop.hpp" #include "aidge/operator/Stack.hpp" @@ -28,6 +29,7 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/ParallelScheduler.hpp" +#include "aidge/graph/Testing.hpp" #include "aidge/backend/cpu/operator/FCImpl.hpp" #include "aidge/backend/cpu/operator/ConvImpl.hpp" @@ -520,6 +522,57 @@ TEST_CASE("[cpu/scheduler] Accumulate", "[scheduler]") { REQUIRE(*output == *expectedOutput); } +TEST_CASE("[cpu/scheduler] Branch", "[scheduler]") { + std::shared_ptr<Tensor> in = std::make_shared<Tensor>( + Array2D<float, 2, 3>{{{1, 2, 3}, {4, 5, 6}}}); + + std::shared_ptr<GraphView> g = Sequential({ + Producer(in, "input"), + Parallel({ + Sequential({ + GenericOperator("b0_op1", {InputCategory::Data}, 1), + GenericOperator("b0_op2", {InputCategory::Data}, 1), + GenericOperator("b0_op3", {InputCategory::Data}, 1), + GenericOperator("b0_op4", {InputCategory::Data}, 1), + GenericOperator("b0_op5", {InputCategory::Data}, 1) + }), + Sequential({ + GenericOperator("b1_op1", {InputCategory::Data}, 1), + GenericOperator("b1_op2", {InputCategory::Data}, 1), + GenericOperator("b1_op3", {InputCategory::Data}, 1) + }), + Sequential({ + GenericOperator("b2_op1", {InputCategory::Data}, 1) + }) + }), + GenericOperator("op1", {InputCategory::Data, InputCategory::Data, InputCategory::Data}, 1), + GenericOperator("op2", {InputCategory::Data}, 1), + GenericOperator("op3", {InputCategory::Data}, 1) + }); + + g->save("branch_forwarded"); + + auto scheduler = SequentialScheduler(g); + scheduler.generateScheduling(); + scheduler.saveStaticSchedulingDiagram("branch_scheduling"); + + // Default scheduling order is not necessarily determinist, but is garanteed to be correct in every case. + // This behavior might change in the future. + auto seqSchedule = scheduler.Scheduler::getSequentialStaticScheduling(0, Scheduler::SchedulingPolicy::Default); + fmt::println("seqSchedule = {}", seqSchedule); + + scheduler.tagForkBranches(); + g->save("branch_forwarded_tag"); + + seqSchedule = scheduler.Scheduler::getSequentialStaticScheduling(0, Scheduler::SchedulingPolicy::ShortestBranchFirst); + REQUIRE(nodePtrTo(seqSchedule, nodePtrToType) == std::vector<std::string>{ + "Producer", "b2_op1", "b1_op1", "b1_op2", "b1_op3", "b0_op1", "b0_op2", "b0_op3", "b0_op4", "b0_op5", "op1", "op2", "op3"}); + + seqSchedule = scheduler.Scheduler::getSequentialStaticScheduling(0, Scheduler::SchedulingPolicy::LonguestBranchFirst); + REQUIRE(nodePtrTo(seqSchedule, nodePtrToType) == std::vector<std::string>{ + "Producer", "b0_op1", "b0_op2", "b0_op3", "b0_op4", "b0_op5", "b1_op1", "b1_op2", "b1_op3", "b2_op1", "op1", "op2", "op3"}); +} + #ifdef WITH_OPENSSL TEST_CASE("[cpu/scheduler] Select", "[scheduler]") { std::shared_ptr<Tensor> in = std::make_shared<Tensor>( -- GitLab