From 225642bbe01a02d0a9c25cb9b58b7a26b9935a09 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 19 Jun 2024 18:11:21 +0200
Subject: [PATCH] Fixed grad check

---
 unit_tests/data/Test_Tensor.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp
index 655fd725e..62e90dcbd 100644
--- a/unit_tests/data/Test_Tensor.cpp
+++ b/unit_tests/data/Test_Tensor.cpp
@@ -40,7 +40,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") {
             (T_default.dims() == std::vector<DimSize_t>({})) &&
             (T_default.strides() == std::vector<DimSize_t>({1})) &&
             (T_default.getImpl() == nullptr) &&
-            (T_default.grad() == nullptr) &&
+            (T_default.grad() != nullptr) &&
             (T_default.isContiguous() == true)
         ));
     }
@@ -53,7 +53,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") {
             (T.dims() == std::vector<DimSize_t>({})) &&
             (T.strides() == std::vector<DimSize_t>({1})) &&
             (T.getImpl() != nullptr) &&
-            (T.grad() == nullptr) &&
+            (T.grad() != nullptr) &&
             (T.isContiguous() == true)
         ));
     }
@@ -67,7 +67,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") {
             (T.dims() == Tdims) &&
             (T.strides() == std::vector<DimSize_t>({5040,2520,840,210,42,7,1})) &&
             (T.getImpl() == nullptr) &&
-            (T.grad() == nullptr) &&
+            (T.grad() != nullptr) &&
             (T.isContiguous() == true)
         ));
     }
@@ -83,7 +83,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") {
             (T.dims() == std::vector<DimSize_t>({2})) &&
             (T.strides() == std::vector<DimSize_t>({1})) &&
             (T.getImpl() != nullptr) &&
-            (T.grad() == nullptr) &&
+            (T.grad() != nullptr) &&
             (T.isContiguous() == true)
         ));
 
@@ -97,7 +97,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") {
             (T.dims() == std::vector<DimSize_t>({2,2,2})) &&
             (T.strides() == std::vector<DimSize_t>({4,2,1})) &&
             (T.getImpl() != nullptr) &&
-            (T.grad() == nullptr) &&
+            (T.grad() != nullptr) &&
             (T.isContiguous() == true)
         ));
         REQUIRE_NOTHROW(T = Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}});
@@ -113,7 +113,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") {
             (T.dims() == std::vector<DimSize_t>({2,2,2,2})) &&
             (T.strides() == std::vector<DimSize_t>({8,4,2,1})) &&
             (T.getImpl() != nullptr) &&
-            (T.grad() == nullptr) &&
+            (T.grad() != nullptr) &&
             (T.isContiguous() == true)
         ));
     }
@@ -157,7 +157,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") {
                 (T.dims() == Tclone.dims()) &&
                 (T.strides() == Tclone.strides()) &&
                 (T.getImpl() != Tclone.getImpl()) &&
-                (Tclone.grad() == nullptr) &&
+                (Tclone.grad() != nullptr) &&
                 (Tclone.isContiguous() == true)
             ));
             REQUIRE(Tclone == T);
-- 
GitLab