diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index f09cf5d91d60d05762cc86a8fb9ca1400cf1be8b..7526680d46ed86268fdf021f9da5d73f6d0b263f 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -62,22 +62,15 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { ** Referes to the cudnnSetTensorNdDescriptor documentation from : ** https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html **/ - std::vector<int> dims(4,1); - std::vector<int> strides(4,1); - int stride = 1; - - for (unsigned int dim = 0; dim < 4; ++dim) { - if(dim < mTensor.nbDims()) { - dims[dim] = mTensor.dims()[dim]; - strides[dim] = stride; - stride *= mTensor.dims()[dim]; - } - } + std::vector<int> dims(mTensor.dims().begin(), mTensor.dims().end()); + + if (dims.size() < 4) + dims.resize(4, 1); + + std::vector<int> strides(dims.size(), 1); - for (unsigned int dim = 4; dim < mTensor.nbDims(); ++dim) { - dims.push_back(mTensor.dims()[dim]); - strides.push_back(stride); - stride *= mTensor.dims()[dim]; + for (size_t dim = 1; dim < dims.size(); ++dim) { + strides[dims.size() - dim - 1] = strides[dims.size() - dim] * dims[dims.size() - dim]; } CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor,