From 4a72555375f997ef9ab850413b3ea4ba6890ec83 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Mon, 12 Feb 2024 15:29:03 +0000 Subject: [PATCH] Update TensorImpl constructor to take the tensor dimensions instead of the number of elements. --- include/aidge/backend/cuda/data/TensorImpl.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index 8b62b2d..d7ccdaa 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -69,12 +69,13 @@ private: public: static constexpr const char *Backend = "cuda"; - TensorImpl_cuda(DeviceIdx_t device, NbElts_t length) : TensorImpl(Backend, device, length), mDataOwner(nullptr, cudaDelete) {} + TensorImpl_cuda(DeviceIdx_t device, std::vector<DimSize_t> dims) : TensorImpl(Backend, device, dims), mDataOwner(nullptr, cudaDelete) {} + bool operator==(const TensorImpl &otherImpl) const override final; - static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, NbElts_t length) { - return std::make_shared<TensorImpl_cuda<T>>(device, length); + static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, std::vector<DimSize_t> dims) { + return std::make_shared<TensorImpl_cuda<T>>(device, dims); } // native interface -- GitLab