From 15242a20c2e3f0eec54d0d8d98f474ee1dccfa2e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 5 Jan 2024 14:40:18 +0100 Subject: [PATCH] Fixed reviewed issues --- include/aidge/backend/cuda/data/TensorImpl.hpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index be6f9f3..c61e926 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -79,7 +79,7 @@ public: std::size_t size() const override { return mData.size(); } std::size_t scalarSize() const override { return sizeof(T); } - void setDevice(int device) override { + void setDevice(DeviceIdx_t device) override { mDevice = device; } @@ -154,7 +154,7 @@ public: } } - void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override { + void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override { AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice)); } @@ -169,14 +169,14 @@ public: CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost)); } - void *rawPtr() override { + void *rawPtr(NbElts_t offset = 0) override { lazyInit(); - return mData.data(); + return (mData.data() + offset); }; - const void *rawPtr() const override { + const void *rawPtr(NbElts_t offset = 0) const override { AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr"); - return mData.data(); + return (mData.data() + offset); }; const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override { @@ -212,11 +212,6 @@ public: return mCudnnTensor; } - void* getRawPtr(NbElts_t idx) override final { - AIDGE_ASSERT(idx < mData.size(), "idx out of range"); - return static_cast<void*>(static_cast<T*>(rawPtr()) + idx); - }; - void setRawPtr(void *ptr, NbElts_t length) override final { AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity"); mData = future_std::span<T>(static_cast<T *>(ptr), length); -- GitLab