From 0e5611a5c4ba9cf1aeae6a10614c4435306fdaf2 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 21 Sep 2023 13:28:00 +0200
Subject: [PATCH] Fixed Conv unit test

---
 .../aidge/backend/cuda/data/TensorImpl.hpp    | 23 +++++++------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index f09cf5d..7526680 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -62,22 +62,15 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_  {
                 **      Referes to the cudnnSetTensorNdDescriptor documentation from :
                 **      https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html
                 **/
-                std::vector<int> dims(4,1);
-                std::vector<int> strides(4,1);
-                int stride = 1;
-
-                for (unsigned int dim = 0; dim < 4; ++dim) {
-                    if(dim < mTensor.nbDims()) {
-                        dims[dim] = mTensor.dims()[dim];
-                        strides[dim] = stride;
-                        stride  *= mTensor.dims()[dim];
-                    }
-                }
+                std::vector<int> dims(mTensor.dims().begin(), mTensor.dims().end());
+
+                if (dims.size() < 4)
+                    dims.resize(4, 1);
+
+                std::vector<int> strides(dims.size(), 1);
 
-                for (unsigned int dim = 4; dim < mTensor.nbDims(); ++dim) {
-                    dims.push_back(mTensor.dims()[dim]);
-                    strides.push_back(stride);
-                    stride *= mTensor.dims()[dim];
+                for (size_t dim = 1; dim < dims.size(); ++dim) {
+                    strides[dims.size() - dim - 1] = strides[dims.size() - dim] * dims[dims.size() - dim];
                 }
 
                 CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor,
-- 
GitLab