From 15242a20c2e3f0eec54d0d8d98f474ee1dccfa2e Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 5 Jan 2024 14:40:18 +0100
Subject: [PATCH] Fixed reviewed issues

---
 include/aidge/backend/cuda/data/TensorImpl.hpp | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index be6f9f3..c61e926 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -79,7 +79,7 @@ public:
     std::size_t size() const override { return mData.size(); }
     std::size_t scalarSize() const override { return sizeof(T); }
 
-    void setDevice(int device) override {
+    void setDevice(DeviceIdx_t device) override {
         mDevice = device;
     }
 
@@ -154,7 +154,7 @@ public:
         }
     }
 
-    void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override {
+    void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override {
         AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
         CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
     }
@@ -169,14 +169,14 @@ public:
         CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(), length * sizeof(T), cudaMemcpyDeviceToHost));
     }
 
-    void *rawPtr() override {
+    void *rawPtr(NbElts_t offset = 0) override {
         lazyInit();
-        return mData.data();
+        return (mData.data() + offset);
     };
 
-    const void *rawPtr() const override {
+    const void *rawPtr(NbElts_t offset = 0) const override {
         AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
-        return mData.data();
+        return (mData.data() + offset);
     };
 
     const cudnnTensorDescriptor_t& getCudnnTensorDesc() const override {
@@ -212,11 +212,6 @@ public:
         return mCudnnTensor;
     }
 
-    void* getRawPtr(NbElts_t idx) override final {
-        AIDGE_ASSERT(idx < mData.size(), "idx out of range");
-        return static_cast<void*>(static_cast<T*>(rawPtr()) + idx);
-    };
-
     void setRawPtr(void *ptr, NbElts_t length) override final {
         AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity");
         mData = future_std::span<T>(static_cast<T *>(ptr), length);
-- 
GitLab