From 652eb2810fd6a0a0360f8881cd3e9b41343d8340 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 16 Feb 2025 23:46:45 +0100
Subject: [PATCH] Working concept of with tagConditionalNodes()

---
 unit_tests/scheduler/Test_Scheduler.cpp | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index 5bd86eec..54e57ec4 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -569,6 +569,24 @@ TEST_CASE("[cpu/scheduler] Select", "[scheduler]") {
         Array2D<float, 2, 3>{{{std::sqrt(1), std::sqrt(2), std::sqrt(3)}, {std::sqrt(4), std::sqrt(5), std::sqrt(6)}}});
     auto output = std::static_pointer_cast<OperatorTensor>(g->getNode("select")->getOperator())->getOutput(0);
     REQUIRE(*output == *expectedOutput);
+
+    scheduler.resetScheduling();
+    scheduler.tagConditionalNodes();
+
+    REQUIRE(g->getNode("relu")->attributes()->hasAttr("schedule.cond"));
+    REQUIRE(g->getNode("relu")->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond")
+        == std::set<std::pair<NodePtr, size_t>>{{g->getNode("select"), 0}});
+    REQUIRE(g->getNode("tanh")->attributes()->hasAttr("schedule.cond"));
+    REQUIRE(g->getNode("tanh")->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond")
+        == std::set<std::pair<NodePtr, size_t>>{{g->getNode("select"), 1}});
+    REQUIRE(g->getNode("sqrt")->attributes()->hasAttr("schedule.cond"));
+    REQUIRE(g->getNode("sqrt")->attributes()->getAttr<std::set<std::pair<NodePtr, size_t>>>("schedule.cond")
+        == std::set<std::pair<NodePtr, size_t>>{{g->getNode("select"), 2}});
+    REQUIRE(!g->getNode("input")->attributes()->hasAttr("schedule.cond"));
+
+    scheduler.generateScheduling();
+    scheduler.saveStaticSchedulingDiagram("select_scheduling_tag");
+    REQUIRE_NOTHROW(scheduler.forward(true));
 }
 #endif
 } // namespace Aidge
-- 
GitLab