diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index cca09afdd9e4a2fa694f405085264a6d332884a9..a5680f92760a615ea807a3a137da3c49d3652be7 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -338,9 +338,7 @@ class Tensor : public Data, */ template <std::array<DimSize_t, 1>::size_type DIM> // deducing std::array size_type and declaring DIM accordingly void resize(const std::array<DimSize_t, DIM> &dims) { - static_assert(DIM<=MaxDim,"Too many tensor dimensions required by resize, not supported"); - mDims.assign(dims.begin(), dims.end()); - computeSize(); + resize(std::vector<DimSize_t>(dims.begin(), dims.end())); } /** @@ -504,9 +502,9 @@ class Tensor : public Data, } /** - * @brief From the the 1D index, return the coordinate of an element in the tensor. + * @brief From the the 1D contiguous index, return the coordinate of an element in the tensor. * - * @param flatIdx 1D index of the value considering a flatten tensor. + * @param flatIdx 1D contiguous index of the value considering a flatten, contiguous, tensor. * @return std::vector<DimSize_t> */ std::vector<std::size_t> getCoord(std::size_t flatIdx) const { @@ -521,19 +519,19 @@ class Tensor : public Data, } /** - * @brief From the coordinate returns the 1D index of an element in the tensor. + * @brief From the coordinate returns the 1D contiguous index of an element in the tensor. + * If the number of coordinates is inferior to the number of dimensions, + * the remaining coordinates are assumed to be 0. * * @param coordIdx Coordinate to an element in the tensor - * @return DimSize_t + * @return DimSize_t Contiguous index */ - std::size_t getIdx(std::vector<std::size_t> coordIdx) const { - // std::size_t flatIdx = 0; - // std::size_t stride = 1; + std::size_t getIdx(const std::vector<std::size_t>& coordIdx) const { + AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions"); std::size_t flatIdx = 0; - assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions"); std::size_t i = 0; - for(; i < mDims.size() - 1; ++i){ - assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor"); + for(; i < coordIdx.size() - 1; ++i){ + AIDGE_ASSERT(coordIdx[i] < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor"); flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1]; } return flatIdx + coordIdx[i];