diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index be6f9f3ded1e7507af4d4c54a4e6cc6ecfc0438b..c61e926c88a9baf1fcdf64794c2a975a1b891356 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -79,7 +79,7 @@ public: std::size_t size() const override { return mData.size(); } std::size_t scalarSize() const override { return sizeof(T); } - void setDevice(int device) override { + void setDevice(DeviceIdx_t device) override { mDevice = device; } @@ -154,7 +154,7 @@ public: } } - void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override { + void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override { AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice)); } @@ -169,14 +169,14 @@ public: CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost)); } - void *rawPtr() override { + void *rawPtr(NbElts_t offset = 0) override { lazyInit(); - return mData.data(); + return (mData.data() + offset); }; - const void *rawPtr() const override { + const void *rawPtr(NbElts_t offset = 0) const override { AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr"); - return mData.data(); + return (mData.data() + offset); }; const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override { @@ -212,11 +212,6 @@ public: return mCudnnTensor; } - void* getRawPtr(NbElts_t idx) override final { - AIDGE_ASSERT(idx < mData.size(), "idx out of range"); - return static_cast<void*>(static_cast<T*>(rawPtr()) + idx); - }; - void setRawPtr(void *ptr, NbElts_t length) override final { AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity"); mData = future_std::span<T>(static_cast<T *>(ptr), length);