diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index 5bd86eec01c922e50e80fe837c567091ac768b1f..54e57ec44a9b803cdba0812ceebbac35c2445adf 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -569,6 +569,24 @@ TEST_CASE("[cpu/scheduler] Select", "[scheduler]") { Array2D<float, 2, 3>{{{std::sqrt(1), std::sqrt(2), std::sqrt(3)}, {std::sqrt(4), std::sqrt(5), std::sqrt(6)}}}); auto output = std::static_pointer_cast<OperatorTensor>(g->getNode("select")->getOperator())->getOutput(0); REQUIRE(*output == *expectedOutput); + + scheduler.resetScheduling(); + scheduler.tagConditionalNodes(); + + REQUIRE(g->getNode("relu")->attributes()->hasAttr("schedule.cond")); + REQUIRE(g->getNode("relu")->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond") + == std::set<std::pair<NodePtr, size_t>>{{g->getNode("select"), 0}}); + REQUIRE(g->getNode("tanh")->attributes()->hasAttr("schedule.cond")); + REQUIRE(g->getNode("tanh")->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond") + == std::set<std::pair<NodePtr, size_t>>{{g->getNode("select"), 1}}); + REQUIRE(g->getNode("sqrt")->attributes()->hasAttr("schedule.cond")); + REQUIRE(g->getNode("sqrt")->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond") + == std::set<std::pair<NodePtr, size_t>>{{g->getNode("select"), 2}}); + REQUIRE(!g->getNode("input")->attributes()->hasAttr("schedule.cond")); + + scheduler.generateScheduling(); + scheduler.saveStaticSchedulingDiagram("select_scheduling_tag"); + REQUIRE_NOTHROW(scheduler.forward(true)); } #endif } // namespace Aidge