Skip to content
Snippets Groups Projects
Commit d34c4621 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Revert "Removed unrelated change"

This reverts commit ea9a0a70.
parent ea9a0a70
No related branches found
No related tags found
1 merge request!149Fix mean computation for integers
Pipeline #67888 passed
......@@ -17,6 +17,7 @@
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/Pop.hpp"
#include "aidge/operator/Stack.hpp"
......@@ -28,6 +29,7 @@
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/ParallelScheduler.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/backend/cpu/operator/FCImpl.hpp"
#include "aidge/backend/cpu/operator/ConvImpl.hpp"
......@@ -520,6 +522,57 @@ TEST_CASE("[cpu/scheduler] Accumulate", "[scheduler]") {
REQUIRE(*output == *expectedOutput);
}
TEST_CASE("[cpu/scheduler] Branch", "[scheduler]") {
std::shared_ptr<Tensor> in = std::make_shared<Tensor>(
Array2D<float, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
std::shared_ptr<GraphView> g = Sequential({
Producer(in, "input"),
Parallel({
Sequential({
GenericOperator("b0_op1", {InputCategory::Data}, 1),
GenericOperator("b0_op2", {InputCategory::Data}, 1),
GenericOperator("b0_op3", {InputCategory::Data}, 1),
GenericOperator("b0_op4", {InputCategory::Data}, 1),
GenericOperator("b0_op5", {InputCategory::Data}, 1)
}),
Sequential({
GenericOperator("b1_op1", {InputCategory::Data}, 1),
GenericOperator("b1_op2", {InputCategory::Data}, 1),
GenericOperator("b1_op3", {InputCategory::Data}, 1)
}),
Sequential({
GenericOperator("b2_op1", {InputCategory::Data}, 1)
})
}),
GenericOperator("op1", {InputCategory::Data, InputCategory::Data, InputCategory::Data}, 1),
GenericOperator("op2", {InputCategory::Data}, 1),
GenericOperator("op3", {InputCategory::Data}, 1)
});
g->save("branch_forwarded");
auto scheduler = SequentialScheduler(g);
scheduler.generateScheduling();
scheduler.saveStaticSchedulingDiagram("branch_scheduling");
// Default scheduling order is not necessarily determinist, but is garanteed to be correct in every case.
// This behavior might change in the future.
auto seqSchedule = scheduler.Scheduler::getSequentialStaticScheduling(0, Scheduler::SchedulingPolicy::Default);
fmt::println("seqSchedule = {}", seqSchedule);
scheduler.tagForkBranches();
g->save("branch_forwarded_tag");
seqSchedule = scheduler.Scheduler::getSequentialStaticScheduling(0, Scheduler::SchedulingPolicy::ShortestBranchFirst);
REQUIRE(nodePtrTo(seqSchedule, nodePtrToType) == std::vector<std::string>{
"Producer", "b2_op1", "b1_op1", "b1_op2", "b1_op3", "b0_op1", "b0_op2", "b0_op3", "b0_op4", "b0_op5", "op1", "op2", "op3"});
seqSchedule = scheduler.Scheduler::getSequentialStaticScheduling(0, Scheduler::SchedulingPolicy::LonguestBranchFirst);
REQUIRE(nodePtrTo(seqSchedule, nodePtrToType) == std::vector<std::string>{
"Producer", "b0_op1", "b0_op2", "b0_op3", "b0_op4", "b0_op5", "b1_op1", "b1_op2", "b1_op3", "b2_op1", "op1", "op2", "op3"});
}
#ifdef WITH_OPENSSL
TEST_CASE("[cpu/scheduler] Select", "[scheduler]") {
std::shared_ptr<Tensor> in = std::make_shared<Tensor>(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment