From 599643356db772cdbb68f7b985f2f9dc14efbac7 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 26 Mar 2024 15:43:57 +0000
Subject: [PATCH] Upd unit-test for SequentialScheduler::backward() member
 function

---
 unit_tests/scheduler/Test_Scheduler.cpp | 42 +++++++++++++++++--------
 1 file changed, 29 insertions(+), 13 deletions(-)

diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index 7321c151..1fa0e577 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -349,25 +349,41 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
 }
 
 TEST_CASE("[cpu/scheduler] SequentialScheduler(backward)", "[scheduler][backward]") {
+
+    // create GraphView
     std::shared_ptr<GraphView> gv = Sequential({ReLU("relu0"), Sqrt("srqt0"), ReLU("relu1")});
 
     std::shared_ptr<Tensor> inputTensor =
-            std::make_shared<Tensor>(Array4D<int, 2, 1, 5, 5>{{{{{0, 1, 2, 3, 4},
-                                                                 {5, 6, 7, 8, 9},
-                                                                 {10, 11, 12, 13, 14},
-                                                                 {15, 16, 17, 18, 19},
-                                                                 {20, 21, 22, 23, 24}}},
-                                                               {{{25, 26, 27, 28, 29},
-                                                                 {30, 31, 32, 33, 34},
-                                                                 {35, 36, 37, 38, 39},
-                                                                 {40, 41, 42, 43, 44},
-                                                                 {45, 46, 47, 48, 49}}}}});
+            std::make_shared<Tensor>(Array4D<float, 2, 1, 5, 5>{{{{{0.0f,  1.0f,  2.0f,  3.0f,  4.0f},
+                                                                 {5.0f,  6.0f,  7.0f,  8.0f,  9.0f},
+                                                                {10.0f, 11.0f, 12.0f, 13.0f, 14.0f},
+                                                                {15.0f, 16.0f, 17.0f, 18.0f, 19.0f},
+                                                                {20.0f, 21.0f, 22.0f, 23.0f, 24.0f}}},
+                                                              {{{25.0f, 26.0f, 27.0f, 28.0f, 29.0f},
+                                                                {30.0f, 31.0f, 32.0f, 33.0f, 34.0f},
+                                                                {35.0f, 36.0f, 37.0f, 38.0f, 39.0f},
+                                                                {40.0f, 41.0f, 42.0f, 43.0f, 44.0f},
+                                                                {45.0f, 46.0f, 47.0f, 48.0f, 49.0f}}}}});
     auto label = inputTensor;
     // implem already set to default
     auto myProd = Producer(inputTensor, "prod");
     myProd -> addChild(gv);
-    gv -> compile("cpu", DataType::Int32);
+    gv -> compile("cpu", DataType::Float32);
     compile_gradient(gv);
     SequentialScheduler scheduler(gv);
-    scheduler.backward();
-}
\ No newline at end of file
+    scheduler.forward();
+    auto predictedOutput = gv->getOrderedOutputs()[0];
+
+    std::shared_ptr<Tensor> targetOutput =
+          std::make_shared<Tensor>(Array4D<float, 2, 1, 5, 5>{{{{{0.0f, 1.0f, 1.0f, 2.0f, 2.0f},
+                                                                 {2.0f, 2.0f, 3.0f, 3.0f, 3.0f},
+                                                                 {3.0f, 3.0f, 3.0f, 4.0f, 4.0f},
+                                                                 {4.0f, 4.0f, 4.0f, 4.0f, 4.0f},
+                                                                 {4.0f, 5.0f, 5.0f, 5.0f, 5.0f}}},
+                                                               {{{5.0f, 5.0f, 5.0f, 5.0f, 5.0f},
+                                                                 {5.0f, 6.0f, 6.0f, 6.0f, 6.0f},
+                                                                 {6.0f, 6.0f, 6.0f, 6.0f, 6.0f},
+                                                                 {6.0f, 6.0f, 6.0f, 7.0f, 7.0f},
+                                                                 {7.0f, 7.0f, 7.0f, 7.0f, 7.0f}}}}});
+
+    REQUIRE_NOTHROW(scheduler.backward(predictedOutput - targetOutput));
-- 
GitLab