From abbfe683b5a035bb4bc55d4cbe4846822c8f89f3 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 7 Feb 2024 14:28:59 +0000
Subject: [PATCH] [Upd] Scheduler unit_test for 'backward()'

---
 unit_tests/scheduler/Test_Scheduler.cpp | 27 ++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index 025ca8ba..dc13c6e5 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -20,6 +20,7 @@
 #include "aidge/scheduler/Scheduler.hpp"
 
 #include "aidge/backend/cpu.hpp"
+#include "aidge/recipies/GraphViewHelper.hpp"
 
 using namespace Aidge;
 
@@ -300,7 +301,7 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
 
         std::vector<std::shared_ptr<Aidge::Tensor>> dataIn = {inputTensor};
         REQUIRE_NOTHROW(scheduler.forward(true, false, dataIn));
-        
+
         scheduler.saveSchedulingDiagram("schedulingSequential");
 
         std::shared_ptr<Tensor> expectedOutput1 = std::make_shared<Tensor>(Array4D<int, 2, 3, 3, 3>{
@@ -345,4 +346,28 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
         bool equal4 = (*other4 == expectedOutput4);
         REQUIRE(equal4);
     }
+}
+
+TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward]") {
+    std::shared_ptr<GraphView> gv = Sequential({ReLU(), ReLU()});
+
+    std::shared_ptr<Tensor> inputTensor =
+            std::make_shared<Tensor>(Array4D<int, 2, 1, 5, 5>{{{{{0, 1, 2, 3, 4},
+                                                                 {5, 6, 7, 8, 9},
+                                                                 {10, 11, 12, 13, 14},
+                                                                 {15, 16, 17, 18, 19},
+                                                                 {20, 21, 22, 23, 24}}},
+                                                               {{{25, 26, 27, 28, 29},
+                                                                 {30, 31, 32, 33, 34},
+                                                                 {35, 36, 37, 38, 39},
+                                                                 {40, 41, 42, 43, 44},
+                                                                 {45, 46, 47, 48, 49}}}}});
+    auto label = inputTensor;
+    // implem already set to default
+    auto myProd = Producer(inputTensor, "prod");
+    myProd -> addChild(gv);
+    gv -> compile("cpu", DataType::Float32);
+    compile_gradient(gv);
+    SequentialScheduler scheduler(gv);
+    scheduler.backward();
 }
\ No newline at end of file
-- 
GitLab