From da92267d078f037928b051ebb10ef59e0bc9697c Mon Sep 17 00:00:00 2001
From: Christophe Guillon <christophe.guillon@inria.fr>
Date: Fri, 12 Jul 2024 13:42:28 +0200
Subject: [PATCH] [Test] Add unit tests for Tensor access by coords

Add tests for get()/set()/getIdx() of tensor values by index
and by coordinates.

Includes tests of 0-rank coordinates which should work and
return the first element which is also the single value for
a 0-rank tensor.
---
 unit_tests/data/Test_Tensor.cpp | 67 +++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)

diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp
index a536f113f..d5cd8cdcf 100644
--- a/unit_tests/data/Test_Tensor.cpp
+++ b/unit_tests/data/Test_Tensor.cpp
@@ -298,6 +298,73 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") {
         }
     }
 
+    SECTION("Tensor set/get") {
+        // Test set with idx and get with idx and coords on different tensors ranks (including 0-rank)
+        // with different coords ranks (including 0-rank).
+        for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
+            // Test tensors of rank 0 to 3
+            for (std::size_t nb_dims = 0; nb_dims <= 3; ++nb_dims) {
+                std::vector<std::size_t> dims(nb_dims);
+                for (std::size_t dim = 0; dim < nb_dims; ++dim) {
+                    dims[dim] = dimsDist(gen);
+                }
+
+                size_t size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
+                std::vector<float> values(size);
+                for (auto& valref : values) {
+                    valref = valueDist(gen);
+                }
+
+                std::unique_ptr<float[]> x_array(new float[size]);
+                for (std::size_t i = 0; i < size; ++i) {
+                    x_array[i] = values[i];
+                }
+
+                // Initialize Tensor with a host backend
+                Tensor x{dims};
+                x.setDataType(DataType::Float32);
+                x.setBackend("cpu");
+                x.getImpl()->setRawPtr(x_array.get(), x.size());
+                REQUIRE(x.getImpl()->hostPtr() != nullptr);
+                REQUIRE(x.isContiguous());
+
+                // Test get() and set() values by index
+                for (std::size_t i = 0; i < size; ++i) {
+                    REQUIRE_NOTHROW(x.set(i, values[i]));
+                }
+                for (std::size_t i = 0; i < size; ++i) {
+                    float val;
+                    REQUIRE_NOTHROW(val = x.get<float>(i));
+                    REQUIRE(val == values[i]);
+                }
+
+                // Test get() and set() by coords
+                // We create coords of rank 0 to the number of dimensions
+                for (std::size_t coord_size = 0; coord_size < dims.size(); ++coord_size) {
+                    std::vector<std::size_t> coords(coord_size);
+                    for (std::size_t coord_idx = 0; coord_idx < coord_size; ++coord_idx) {
+                        std::size_t dim_idx = (dimsDist(gen)-1) % dims[coord_idx];
+                        coords[coord_idx] = dim_idx;
+                    }
+                    std::size_t flat_idx, flat_storage_idx;
+                    // As it is continuous we have getIdx() == getStorageIdx()
+                    REQUIRE_NOTHROW(flat_idx = x.getIdx(coords));
+                    REQUIRE_NOTHROW(flat_storage_idx = x.getStorageIdx(coords));
+                    REQUIRE(flat_storage_idx == flat_idx);
+                    float val, val_flat;
+                    // Test get() by index and by coords
+                    REQUIRE_NOTHROW(val_flat = x.get<float>(flat_idx));
+                    REQUIRE_NOTHROW(val = x.get<float>(coords));
+                    REQUIRE(val == val_flat);
+                    REQUIRE(val == values[flat_idx]);
+                    // Test set() by coords, also update the reference array
+                    REQUIRE_NOTHROW(x.set(coords, val + 1));
+                    values[flat_idx] += 1;
+                }
+            }
+        }
+    }
+
     SECTION("Tensor extract") {
         bool equal;
 
-- 
GitLab