From 39cf8659792837f85a9e0ac1b64b157a7adc76c6 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 7 Dec 2023 11:09:41 +0100 Subject: [PATCH] Added automatic GPU reallocation --- .../aidge/backend/cuda/data/TensorImpl.hpp | 57 +++++++++++++------ src/data/TensorImpl.cu | 4 +- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index 80fc8d6..4f66a93 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -6,6 +6,7 @@ #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/future_std/span.hpp" #include "aidge/backend/cuda/utils/CudaUtils.hpp" #include "aidge/backend/cuda/utils/CudaContext.hpp" @@ -34,16 +35,31 @@ public: template <class T> class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { - private: +private: + static T* cudaAlloc(NbElts_t length) { + T* data; + CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&data), length * sizeof(T))); + return data; + } + + static void cudaDelete(T* data) { + // Should not be called if data is nullptr, according to the standard + cudaFree(data); + } + +private: const Tensor &mTensor; // Impl needs to access Tensor information, but is not // supposed to change it! 
- T* mData = nullptr; + /// Pointer to the data and its capacity + future_std::span<T> mData; + /// If this instance owns the data, std::unique_ptr manages it + std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner; mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr; - public: +public: static constexpr const char *Backend = "cuda"; - TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {} + TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor), mDataOwner(nullptr, cudaDelete) {} bool operator==(const TensorImpl &otherImpl) const override final; @@ -52,7 +68,7 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { } // native interface - const T* data() const { return mData; } + const future_std::span<T>& data() const { return mData; } std::size_t scalarSize() const override { return sizeof(T); } @@ -133,13 +149,13 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { } void *rawPtr() override { - lazyInit(reinterpret_cast<void**>(&mData)); - return mData; + lazyInit(); + return mData.data(); }; const void *rawPtr() const override { - AIDGE_ASSERT(mData != nullptr, "accessing uninitialized const rawPtr"); - return mData; + AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr"); + return mData.data(); }; void* getRaw(std::size_t idx) { @@ -180,21 +196,26 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { } virtual ~TensorImpl_cuda() { - if (mData != nullptr) - cudaFree(mData); - if (mCudnnTensor != nullptr) cudnnDestroyTensorDescriptor(mCudnnTensor); } - void setRawPtr(void* /*ptr*/) override final { - printf("Not implemented yet."); + void setRawPtr(void *ptr, NbElts_t length) override final { + AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity"); + mData = future_std::span<T>(static_cast<T *>(ptr), length); + mDataOwner.reset(); }; - private: - void lazyInit(void** data) { - if (*data == 
nullptr) - CHECK_CUDA_STATUS(cudaMalloc(data, mTensor.size() * sizeof(T))); +private: + void lazyInit() { + AIDGE_INTERNAL_ASSERT(mTensor.dataType() == NativeType<T>::type); + + if (mData.size() < mTensor.size()) { + // Need more data, a re-allocation will occur + AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "trying to enlarge non-owned data"); + mDataOwner.reset(cudaAlloc(mTensor.size())); + mData = future_std::span<T>(mDataOwner.get(), mTensor.size()); + } } }; diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu index beb76f6..8ce0e2d 100644 --- a/src/data/TensorImpl.cu +++ b/src/data/TensorImpl.cu @@ -29,7 +29,7 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const { if (mTensor.size() != otherImplCuda.mTensor.size()) return false; - thrust::device_ptr<T> thrustData(mData); - thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData); + thrust::device_ptr<T> thrustData(mData.data()); + thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData.data()); return thrust::equal(thrustData, thrustData + mTensor.size(), thrustOtherData); } -- GitLab