diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index 80fc8d64700f0991cb73e3dc051ecf0f156ece82..4f66a9321e46c11aa47e7d4b011b2d6a4f0e65a7 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -6,6 +6,7 @@ #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/future_std/span.hpp" #include "aidge/backend/cuda/utils/CudaUtils.hpp" #include "aidge/backend/cuda/utils/CudaContext.hpp" @@ -34,16 +35,31 @@ public: template <class T> class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { - private: +private: + static T* cudaAlloc(NbElts_t length) { + T* data; + CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&data), length * sizeof(T))); + return data; + } + + static void cudaDelete(T* data) { + // Should not be called if data is nullptr, according to the standard + cudaFree(data); + } + +private: const Tensor &mTensor; // Impl needs to access Tensor information, but is not // supposed to change it! - T* mData = nullptr; + /// Pointer to the data and its capacity + future_std::span<T> mData; + /// If this instance own the data, std::unique_ptr manages it + std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner; mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr; - public: +public: static constexpr const char *Backend = "cuda"; - TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {} + TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor), mDataOwner(nullptr, cudaDelete) {} bool operator==(const TensorImpl &otherImpl) const override final; @@ -52,7 +68,7 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { } // native interface - const T* data() const { return mData; } + const future_std::span<T>& data() const { return mData; } std::size_t scalarSize() const override { return sizeof(T); } @@ -133,13 +149,13 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { } void *rawPtr() override { - lazyInit(reinterpret_cast<void**>(&mData)); - return mData; + lazyInit(); + return mData.data(); }; const void *rawPtr() const override { - AIDGE_ASSERT(mData != nullptr, "accessing uninitialized const rawPtr"); - return mData; + AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr"); + return mData.data(); }; void* getRaw(std::size_t idx) { @@ -180,21 +196,26 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ { } virtual ~TensorImpl_cuda() { - if (mData != nullptr) - cudaFree(mData); - if (mCudnnTensor != nullptr) cudnnDestroyTensorDescriptor(mCudnnTensor); } - void setRawPtr(void* /*ptr*/) override final { - printf("Not implemented yet."); + void setRawPtr(void *ptr, NbElts_t length) override final { + AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity"); + mData = future_std::span<T>(static_cast<T *>(ptr), length); + mDataOwner.reset(); }; - private: - void lazyInit(void** data) { - if (*data == nullptr) - CHECK_CUDA_STATUS(cudaMalloc(data, mTensor.size() * sizeof(T))); +private: + void lazyInit() { + AIDGE_INTERNAL_ASSERT(mTensor.dataType() == NativeType<T>::type); + + if (mData.size() < mTensor.size()) { + // Need more data, a re-allocation will occur + AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "trying to enlarge non-owned data"); + mDataOwner.reset(cudaAlloc(mTensor.size())); + mData = future_std::span<T>(mDataOwner.get(), mTensor.size()); + } } }; diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu index beb76f627c8a63410393d60e09cefaaaeaf6220d..8ce0e2d6d3430ed40c32bba401e362a56e540ff5 100644 --- a/src/data/TensorImpl.cu +++ b/src/data/TensorImpl.cu @@ -29,7 +29,7 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const { if (mTensor.size() != otherImplCuda.mTensor.size()) return false; - thrust::device_ptr<T> thrustData(mData); - thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData); + thrust::device_ptr<T> thrustData(mData.data()); + thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData.data()); return thrust::equal(thrustData, thrustData + mTensor.size(), thrustOtherData); }