From 225642bbe01a02d0a9c25cb9b58b7a26b9935a09 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 19 Jun 2024 18:11:21 +0200 Subject: [PATCH] Fixed grad check --- unit_tests/data/Test_Tensor.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp index 655fd725e..62e90dcbd 100644 --- a/unit_tests/data/Test_Tensor.cpp +++ b/unit_tests/data/Test_Tensor.cpp @@ -40,7 +40,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") { (T_default.dims() == std::vector<DimSize_t>({})) && (T_default.strides() == std::vector<DimSize_t>({1})) && (T_default.getImpl() == nullptr) && - (T_default.grad() == nullptr) && + (T_default.grad() != nullptr) && (T_default.isContiguous() == true) )); } @@ -53,7 +53,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") { (T.dims() == std::vector<DimSize_t>({})) && (T.strides() == std::vector<DimSize_t>({1})) && (T.getImpl() != nullptr) && - (T.grad() == nullptr) && + (T.grad() != nullptr) && (T.isContiguous() == true) )); } @@ -67,7 +67,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") { (T.dims() == Tdims) && (T.strides() == std::vector<DimSize_t>({5040,2520,840,210,42,7,1})) && (T.getImpl() == nullptr) && - (T.grad() == nullptr) && + (T.grad() != nullptr) && (T.isContiguous() == true) )); } @@ -83,7 +83,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") { (T.dims() == std::vector<DimSize_t>({2})) && (T.strides() == std::vector<DimSize_t>({1})) && (T.getImpl() != nullptr) && - (T.grad() == nullptr) && + (T.grad() != nullptr) && (T.isContiguous() == true) )); @@ -97,7 +97,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") { (T.dims() == std::vector<DimSize_t>({2,2,2})) && (T.strides() == std::vector<DimSize_t>({4,2,1})) && (T.getImpl() != nullptr) && - (T.grad() == nullptr) && + (T.grad() != nullptr) && (T.isContiguous() == true) )); REQUIRE_NOTHROW(T = Array3D<int, 2, 2, 2>{{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}}); @@ -113,7 +113,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") { (T.dims() == std::vector<DimSize_t>({2,2,2,2})) && (T.strides() == std::vector<DimSize_t>({8,4,2,1})) && (T.getImpl() != nullptr) && - (T.grad() == nullptr) && + (T.grad() != nullptr) && (T.isContiguous() == true) )); } @@ -157,7 +157,7 @@ TEST_CASE("[core/data] Tensor(Construction)", "[Tensor][Constructor]") { (T.dims() == Tclone.dims()) && (T.strides() == Tclone.strides()) && (T.getImpl() != Tclone.getImpl()) && - (Tclone.grad() == nullptr) && + (Tclone.grad() != nullptr) && (Tclone.isContiguous() == true) )); REQUIRE(Tclone == T); -- GitLab