diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 108f1f2b4af12b3501dbb247d17052e42ebb70ed..b4c5de2ebe5c18e91da8fe4474ea74cf338b0fa6 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -639,6 +639,7 @@ public: * the remaining coordinates are assumed to be 0. * Beware: the contiguous index will only correspond to the storage index * if the tensor is contiguous! + * Note that the coordIdx may be an empty vector. * * @param coordIdx Coordinate to an element in the tensor * @return DimSize_t Contiguous index @@ -646,12 +647,13 @@ public: std::size_t getIdx(const std::vector<std::size_t>& coordIdx) const { AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions"); std::size_t flatIdx = 0; - std::size_t i = 0; - for(; i < coordIdx.size() - 1; ++i) { - AIDGE_ASSERT(coordIdx[i] < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor"); - flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1]; + for(std::size_t i = 0; i < mDims.size(); ++i) { + auto coord = i < coordIdx.size() ? coordIdx[i]: 0; + AIDGE_ASSERT(coord < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor"); + auto nextDimSize = i + 1 < mDims.size() ? mDims[i + 1]: 1; + flatIdx = (flatIdx + coord) * nextDimSize; } - return flatIdx + coordIdx[i]; + return flatIdx; } /** @@ -663,10 +665,10 @@ public: * @return DimSize_t Storage index */ std::size_t getStorageIdx(const std::vector<std::size_t>& coordIdx) const { + AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions"); for(std::size_t i = 0; i < coordIdx.size(); ++i) { AIDGE_ASSERT(coordIdx[i] < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor"); } - AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions"); return std::inner_product(coordIdx.cbegin(), coordIdx.cend(), mStrides.cbegin(), DimSize_t(0)); } diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp index a536f113f7d11eb8cec81b5fdbf57909bd70611d..d5cd8cdcfb88beee9aab2393b0c5591c79a70b80 100644 --- a/unit_tests/data/Test_Tensor.cpp +++ b/unit_tests/data/Test_Tensor.cpp @@ -298,6 +298,73 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") { } } + SECTION("Tensor set/get") { + // Test set with idx and get with idx and coords on different tensors ranks (including 0-rank) + // with different coords ranks (including 0-rank). + for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) { + // Test tensors of rank 0 to 3 + for (std::size_t nb_dims = 0; nb_dims <= 3; ++nb_dims) { + std::vector<std::size_t> dims(nb_dims); + for (std::size_t dim = 0; dim < nb_dims; ++dim) { + dims[dim] = dimsDist(gen); + } + + size_t size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()); + std::vector<float> values(size); + for (auto& valref : values) { + valref = valueDist(gen); + } + + std::unique_ptr<float[]> x_array(new float[size]); + for (std::size_t i = 0; i < size; ++i) { + x_array[i] = values[i]; + } + + // Initialize Tensor with a host backend + Tensor x{dims}; + x.setDataType(DataType::Float32); + x.setBackend("cpu"); + x.getImpl()->setRawPtr(x_array.get(), x.size()); + REQUIRE(x.getImpl()->hostPtr() != nullptr); + REQUIRE(x.isContiguous()); + + // Test get() and set() values by index + for (std::size_t i = 0; i < size; ++i) { + REQUIRE_NOTHROW(x.set(i, values[i])); + } + for (std::size_t i = 0; i < size; ++i) { + float val; + REQUIRE_NOTHROW(val = x.get<float>(i)); + REQUIRE(val == values[i]); + } + + // Test get() and set() by coords + // We create coords of rank 0 to the number of dimensions + for (std::size_t coord_size = 0; coord_size < dims.size(); ++coord_size) { + std::vector<std::size_t> coords(coord_size); + for (std::size_t coord_idx = 0; coord_idx < coord_size; ++coord_idx) { + std::size_t dim_idx = (dimsDist(gen)-1) % dims[coord_idx]; + coords[coord_idx] = dim_idx; + } + std::size_t flat_idx, flat_storage_idx; + // As it is continuous we have getIdx() == getStorageIdx() + REQUIRE_NOTHROW(flat_idx = x.getIdx(coords)); + REQUIRE_NOTHROW(flat_storage_idx = x.getStorageIdx(coords)); + REQUIRE(flat_storage_idx == flat_idx); + float val, val_flat; + // Test get() by index and by coords + REQUIRE_NOTHROW(val_flat = x.get<float>(flat_idx)); + REQUIRE_NOTHROW(val = x.get<float>(coords)); + REQUIRE(val == val_flat); + REQUIRE(val == values[flat_idx]); + // Test set() by coords, also update the reference array + REQUIRE_NOTHROW(x.set(coords, val + 1)); + values[flat_idx] += 1; + } + } + } + } + SECTION("Tensor extract") { bool equal;