Skip to content
Snippets Groups Projects
Commit 2ea2cd6b authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'user/guillon/dev/fix-tensor-getidx' into 'dev'

[Tensor] Fix invalid getIdx() method

See merge request eclipse/aidge/aidge_core!171
parents 9aad44ea e44d7267
No related branches found
No related tags found
No related merge requests found
......@@ -639,6 +639,7 @@ public:
* the remaining coordinates are assumed to be 0.
* Beware: the contiguous index will only correspond to the storage index
* if the tensor is contiguous!
* Note that the coordIdx may be an empty vector.
*
* @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t Contiguous index
......@@ -646,12 +647,13 @@ public:
std::size_t getIdx(const std::vector<std::size_t>& coordIdx) const {
AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
std::size_t flatIdx = 0;
std::size_t i = 0;
for(; i < coordIdx.size() - 1; ++i) {
AIDGE_ASSERT(coordIdx[i] < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor");
flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1];
for(std::size_t i = 0; i < mDims.size(); ++i) {
auto coord = i < coordIdx.size() ? coordIdx[i]: 0;
AIDGE_ASSERT(coord < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor");
auto nextDimSize = i + 1 < mDims.size() ? mDims[i + 1]: 1;
flatIdx = (flatIdx + coord) * nextDimSize;
}
return flatIdx + coordIdx[i];
return flatIdx;
}
/**
......@@ -663,10 +665,10 @@ public:
* @return DimSize_t Storage index
*/
std::size_t getStorageIdx(const std::vector<std::size_t>& coordIdx) const {
AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
for(std::size_t i = 0; i < coordIdx.size(); ++i) {
AIDGE_ASSERT(coordIdx[i] < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor");
}
AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
return std::inner_product(coordIdx.cbegin(), coordIdx.cend(), mStrides.cbegin(), DimSize_t(0));
}
......
......@@ -298,6 +298,73 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") {
}
}
SECTION("Tensor set/get") {
// Test set with idx and get with idx and coords on different tensors ranks (including 0-rank)
// with different coords ranks (including 0-rank).
for (std::uint16_t trial = 0; trial < NBTRIALS; ++trial) {
// Test tensors of rank 0 to 3
for (std::size_t nb_dims = 0; nb_dims <= 3; ++nb_dims) {
std::vector<std::size_t> dims(nb_dims);
for (std::size_t dim = 0; dim < nb_dims; ++dim) {
dims[dim] = dimsDist(gen);
}
size_t size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
std::vector<float> values(size);
for (auto& valref : values) {
valref = valueDist(gen);
}
std::unique_ptr<float[]> x_array(new float[size]);
for (std::size_t i = 0; i < size; ++i) {
x_array[i] = values[i];
}
// Initialize Tensor with a host backend
Tensor x{dims};
x.setDataType(DataType::Float32);
x.setBackend("cpu");
x.getImpl()->setRawPtr(x_array.get(), x.size());
REQUIRE(x.getImpl()->hostPtr() != nullptr);
REQUIRE(x.isContiguous());
// Test get() and set() values by index
for (std::size_t i = 0; i < size; ++i) {
REQUIRE_NOTHROW(x.set(i, values[i]));
}
for (std::size_t i = 0; i < size; ++i) {
float val;
REQUIRE_NOTHROW(val = x.get<float>(i));
REQUIRE(val == values[i]);
}
// Test get() and set() by coords
// We create coords of rank 0 to the number of dimensions
for (std::size_t coord_size = 0; coord_size < dims.size(); ++coord_size) {
std::vector<std::size_t> coords(coord_size);
for (std::size_t coord_idx = 0; coord_idx < coord_size; ++coord_idx) {
std::size_t dim_idx = (dimsDist(gen)-1) % dims[coord_idx];
coords[coord_idx] = dim_idx;
}
std::size_t flat_idx, flat_storage_idx;
// As it is continuous we have getIdx() == getStorageIdx()
REQUIRE_NOTHROW(flat_idx = x.getIdx(coords));
REQUIRE_NOTHROW(flat_storage_idx = x.getStorageIdx(coords));
REQUIRE(flat_storage_idx == flat_idx);
float val, val_flat;
// Test get() by index and by coords
REQUIRE_NOTHROW(val_flat = x.get<float>(flat_idx));
REQUIRE_NOTHROW(val = x.get<float>(coords));
REQUIRE(val == val_flat);
REQUIRE(val == values[flat_idx]);
// Test set() by coords, also update the reference array
REQUIRE_NOTHROW(x.set(coords, val + 1));
values[flat_idx] += 1;
}
}
}
}
SECTION("Tensor extract") {
bool equal;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment