diff --git a/unit_tests/Test_ReLUImpl.cpp b/unit_tests/Test_ReLUImpl.cpp
index e86d3126f39c366abe1c8cbcd0d7086ffe477c4c..1ac50c29d4b98cc5311bf270e05206fe64ce3b30 100644
--- a/unit_tests/Test_ReLUImpl.cpp
+++ b/unit_tests/Test_ReLUImpl.cpp
@@ -17,12 +17,13 @@
 #include "aidge/backend/cpu.hpp"
 #include "aidge/backend/cuda.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/utils/TensorUtils.hpp"
 
 using namespace Aidge;
 
 
 TEST_CASE("[gpu/operator] ReLU(forward)", "[ReLU][GPU]") {
-    SECTION("4D Tensor") {
+    SECTION("Constant Input") {
         std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<float,2,2,2,10> {
             {
                 {
@@ -98,30 +99,23 @@ TEST_CASE("[gpu/operator] ReLU(forward)", "[ReLU][GPU]") {
         std::mt19937 gen(rd());
         std::uniform_real_distribution<float> valueDist(
             0.1f, 1.1f); // Random float distribution between 0 and 1
-        std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(2),
+        std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(1),
                                                                std::size_t(10));
 
-        std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(2), std::size_t(4));
-
-        // Create ReLU Operator
-        std::shared_ptr<Node> myReLU = ReLU("myReLU");
-        auto op = std::static_pointer_cast<OperatorTensor>(myReLU->getOperator());
-        op->setDataType(DataType::Float32);
-        op->setBackend("cuda");
-
-        // Create the input Tensor
-        std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>();
-        op->associateInput(0, T0);
-        T0->setDataType(DataType::Float32);
-        T0->setBackend("cuda");
-
-        // To measure execution time of 'AveragePooling_Op::forward()'
+        std::uniform_int_distribution<std::size_t> nbDimsDist(std::size_t(1), std::size_t(8)); // Max nbDims supported by cudnn is 8
+        // To measure execution time of 'forward()'
         std::chrono::time_point<std::chrono::system_clock> start;
         std::chrono::time_point<std::chrono::system_clock> end;
         std::chrono::duration<double, std::micro> duration{};
         std::size_t number_of_operation = 0;
         for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial)
         {
+            // Create ReLU Operator
+            std::shared_ptr<Node> myReLU = ReLU("myReLU");
+            auto op = std::static_pointer_cast<OperatorTensor>(myReLU->getOperator());
+            op->setDataType(DataType::Float32);
+            op->setBackend("cuda");
+
             // generate a random Tensor
             const std::size_t nbDims = nbDimsDist(gen);
             std::vector<std::size_t> dims;
@@ -133,6 +127,13 @@ TEST_CASE("[gpu/operator] ReLU(forward)", "[ReLU][GPU]") {
             const std::size_t nb_elements = std::accumulate(dims.cbegin(), dims.cend(), std::size_t(1), std::multiplies<std::size_t>());
             number_of_operation += nb_elements;
 
+            // Create the input Tensor
+            std::shared_ptr<Tensor> T0 = std::make_shared<Tensor>();
+            T0->setDataType(DataType::Float32);
+            T0->setBackend("cuda");
+            T0->resize(dims);
+            op->associateInput(0, T0);
+
             // Fill input tensor
             float *input_h = new float[nb_elements];
             float *output_h = new float[nb_elements];
@@ -145,7 +146,6 @@ TEST_CASE("[gpu/operator] ReLU(forward)", "[ReLU][GPU]") {
             float *input_d;
             cudaMalloc(reinterpret_cast<void **>(&input_d), sizeof(float) * nb_elements);
             cudaMemcpy(input_d, input_h, sizeof(float) * nb_elements, cudaMemcpyHostToDevice);
-            T0->resize(dims);
             T0->getImpl()->setRawPtr(input_d, nb_elements);
 
             // Run inference
@@ -158,10 +158,7 @@ TEST_CASE("[gpu/operator] ReLU(forward)", "[ReLU][GPU]") {
             float *computedOutput = new float[nb_elements]();
             cudaMemcpy(computedOutput, op->getOutput(0)->getImpl()->rawPtr(), sizeof(float) * nb_elements, cudaMemcpyDeviceToHost);
 
-            for (int i = 0; i < nb_elements; ++i)
-            {
-                REQUIRE(computedOutput[i] == output_h[i]);
-            }
+            REQUIRE(approxEq<float>(*computedOutput, *output_h));
 
             delete[] computedOutput;
             delete[] input_h;