From 0e5611a5c4ba9cf1aeae6a10614c4435306fdaf2 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 21 Sep 2023 13:28:00 +0200 Subject: [PATCH] Fixed Conv unit test --- .../aidge/backend/cuda/data/TensorImpl.hpp | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index f09cf5d..7526680 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -62,22 +62,15 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { ** Referes to the cudnnSetTensorNdDescriptor documentation from : ** https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html **/ - std::vector<int> dims(4,1); - std::vector<int> strides(4,1); - int stride = 1; - - for (unsigned int dim = 0; dim < 4; ++dim) { - if(dim < mTensor.nbDims()) { - dims[dim] = mTensor.dims()[dim]; - strides[dim] = stride; - stride *= mTensor.dims()[dim]; - } - } + std::vector<int> dims(mTensor.dims().begin(), mTensor.dims().end()); + + if (dims.size() < 4) + dims.resize(4, 1); + + std::vector<int> strides(dims.size(), 1); - for (unsigned int dim = 4; dim < mTensor.nbDims(); ++dim) { - dims.push_back(mTensor.dims()[dim]); - strides.push_back(stride); - stride *= mTensor.dims()[dim]; + for (size_t dim = 1; dim < dims.size(); ++dim) { + strides[dims.size() - dim - 1] = strides[dims.size() - dim] * dims[dims.size() - dim]; } CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor, -- GitLab