From 4a72555375f997ef9ab850413b3ea4ba6890ec83 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Mon, 12 Feb 2024 15:29:03 +0000
Subject: [PATCH] Update TensorImpl constructor to take the tensor  dimensions
 instead of the number of elements.

---
 include/aidge/backend/cuda/data/TensorImpl.hpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 8b62b2d..d7ccdaa 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -69,12 +69,13 @@ private:
 public:
     static constexpr const char *Backend = "cuda";
 
-    TensorImpl_cuda(DeviceIdx_t device, NbElts_t length) : TensorImpl(Backend, device, length), mDataOwner(nullptr, cudaDelete) {}
+    TensorImpl_cuda(DeviceIdx_t device, std::vector<DimSize_t> dims) : TensorImpl(Backend, device, dims), mDataOwner(nullptr, cudaDelete) {}
+    
 
     bool operator==(const TensorImpl &otherImpl) const override final;
 
-    static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, NbElts_t length) {
-        return std::make_shared<TensorImpl_cuda<T>>(device, length);
+    static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, std::vector<DimSize_t> dims) {
+        return std::make_shared<TensorImpl_cuda<T>>(device, dims);
     }
 
     // native interface
-- 
GitLab