From 39cf8659792837f85a9e0ac1b64b157a7adc76c6 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 7 Dec 2023 11:09:41 +0100
Subject: [PATCH] Added automatic GPU reallocation

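Store the device buffer as a future_std::span<T> (pointer plus capacity)
together with an optional owning std::unique_ptr whose deleter calls
cudaFree. lazyInit() now reallocates the buffer whenever the tensor
outgrows the current capacity, and setRawPtr() can wrap externally
managed device memory without taking ownership of it.

Minimal standalone sketch of the ownership scheme (illustration only,
not part of the patch; it assumes C++20 std::span as a stand-in for
future_std::span and omits the CHECK_CUDA_STATUS error checking used
in the real implementation):

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <memory>
    #include <span>

    // cudaFree(nullptr) is a no-op, and std::unique_ptr never calls the
    // deleter on a null pointer anyway.
    static void cudaDelete(float* p) { cudaFree(p); }

    int main() {
        std::unique_ptr<float, decltype(&cudaDelete)> owner(nullptr, cudaDelete);
        std::span<float> view;  // pointer + capacity, owned or external

        const std::size_t required = 1024;
        if (view.size() < required) {  // grow: reallocate and rebind the view
            float* raw = nullptr;
            cudaMalloc(reinterpret_cast<void**>(&raw), required * sizeof(float));
            owner.reset(raw);          // previous buffer, if any, is freed
            view = std::span<float>(owner.get(), required);
        }
        std::printf("capacity: %zu elements\n", view.size());
        return 0;                      // owner releases the buffer via cudaFree
    }
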
---
 .../aidge/backend/cuda/data/TensorImpl.hpp    | 57 +++++++++++++------
 src/data/TensorImpl.cu                        |  4 +-
 2 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 80fc8d6..4f66a93 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -6,6 +6,7 @@
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/future_std/span.hpp"
 
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
@@ -34,16 +35,31 @@ public:
 
 template <class T>
 class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_  {
-   private:
+private:
+    static T* cudaAlloc(NbElts_t length) {
+        T* data;
+        CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&data), length * sizeof(T)));
+        return data;
+    }
+
+    static void cudaDelete(T* data) {
+        // Per the standard, std::unique_ptr never calls its deleter on a null pointer, so data is never nullptr here
+        cudaFree(data);
+    }
+
+private:
     const Tensor &mTensor;  // Impl needs to access Tensor information, but is not
                             // supposed to change it!
-    T* mData = nullptr;
+    /// Pointer to the data and its capacity
+    future_std::span<T> mData;
+    /// If this instance owns the data, the std::unique_ptr manages its lifetime
+    std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner;
     mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
 
-   public:
+public:
     static constexpr const char *Backend = "cuda";
 
-    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}
+    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor), mDataOwner(nullptr, cudaDelete) {}
 
     bool operator==(const TensorImpl &otherImpl) const override final;
 
@@ -52,7 +68,7 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_  {
     }
 
     // native interface
-    const T* data() const { return mData; }
+    const future_std::span<T>& data() const { return mData; }
 
     std::size_t scalarSize() const override { return sizeof(T); }
 
@@ -133,13 +149,13 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_  {
     }
 
     void *rawPtr() override {
-        lazyInit(reinterpret_cast<void**>(&mData));
-        return mData;
+        lazyInit();
+        return mData.data();
     };
 
     const void *rawPtr() const override {
-        AIDGE_ASSERT(mData != nullptr, "accessing uninitialized const rawPtr");
-        return mData;
+        AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized or too small const rawPtr");
+        return mData.data();
     };
 
     void* getRaw(std::size_t idx) {
@@ -180,21 +196,26 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_  {
     }
 
     virtual ~TensorImpl_cuda() {
-        if (mData != nullptr)
-            cudaFree(mData);
-
         if (mCudnnTensor != nullptr)
             cudnnDestroyTensorDescriptor(mCudnnTensor);
     }
 
-    void setRawPtr(void* /*ptr*/) override final {
-        printf("Not implemented yet.");
+    void setRawPtr(void *ptr, NbElts_t length) override final {
+        AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity");
+        mData = future_std::span<T>(static_cast<T *>(ptr), length);
+        mDataOwner.reset();
     };
 
-   private:
-    void lazyInit(void** data) {
-        if (*data == nullptr)
-            CHECK_CUDA_STATUS(cudaMalloc(data, mTensor.size() * sizeof(T)));
+private:
+    void lazyInit() {
+        AIDGE_INTERNAL_ASSERT(mTensor.dataType() == NativeType<T>::type);
+
+        if (mData.size() < mTensor.size()) {
+            // Need more data, a re-allocation will occur
+            AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "trying to enlarge non-owned data");
+            mDataOwner.reset(cudaAlloc(mTensor.size()));
+            mData = future_std::span<T>(mDataOwner.get(), mTensor.size());
+        }
     }
 };
 
diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu
index beb76f6..8ce0e2d 100644
--- a/src/data/TensorImpl.cu
+++ b/src/data/TensorImpl.cu
@@ -29,7 +29,7 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     if (mTensor.size() != otherImplCuda.mTensor.size())
         return false;
 
-    thrust::device_ptr<T> thrustData(mData);
-    thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData);
+    thrust::device_ptr<T> thrustData(mData.data());
+    thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData.data());
     return thrust::equal(thrustData, thrustData + mTensor.size(), thrustOtherData);
 }
-- 
GitLab