Skip to content
Snippets Groups Projects
Commit 2d9f6d80 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

getIdx() accepts less dimensions

parent 80f23f80
No related branches found
No related tags found
No related merge requests found
......@@ -338,9 +338,7 @@ class Tensor : public Data,
*/
template <std::array<DimSize_t, 1>::size_type DIM> // deducing std::array size_type and declaring DIM accordingly
void resize(const std::array<DimSize_t, DIM> &dims) {
static_assert(DIM<=MaxDim,"Too many tensor dimensions required by resize, not supported");
mDims.assign(dims.begin(), dims.end());
computeSize();
resize(std::vector<DimSize_t>(dims.begin(), dims.end()));
}
/**
......@@ -504,9 +502,9 @@ class Tensor : public Data,
}
/**
* @brief From the the 1D index, return the coordinate of an element in the tensor.
* @brief From the the 1D contiguous index, return the coordinate of an element in the tensor.
*
* @param flatIdx 1D index of the value considering a flatten tensor.
* @param flatIdx 1D contiguous index of the value considering a flatten, contiguous, tensor.
* @return std::vector<DimSize_t>
*/
std::vector<std::size_t> getCoord(std::size_t flatIdx) const {
......@@ -521,19 +519,19 @@ class Tensor : public Data,
}
/**
* @brief From the coordinate returns the 1D index of an element in the tensor.
* @brief From the coordinate returns the 1D contiguous index of an element in the tensor.
* If the number of coordinates is inferior to the number of dimensions,
* the remaining coordinates are assumed to be 0.
*
* @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t
* @return DimSize_t Contiguous index
*/
std::size_t getIdx(std::vector<std::size_t> coordIdx) const {
// std::size_t flatIdx = 0;
// std::size_t stride = 1;
std::size_t getIdx(const std::vector<std::size_t>& coordIdx) const {
AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
std::size_t flatIdx = 0;
assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions");
std::size_t i = 0;
for(; i < mDims.size() - 1; ++i){
assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor");
for(; i < coordIdx.size() - 1; ++i){
AIDGE_ASSERT(coordIdx[i] < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor");
flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1];
}
return flatIdx + coordIdx[i];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment