Commit 39cf8659 authored by Olivier BICHLER

Added automatic GPU reallocation

parent ce1cbb42
1 merge request: !4 Add Convert operator (a.k.a. Transmitter)
Pipeline #35455 failed
@@ -6,6 +6,7 @@
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/future_std/span.hpp"
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
@@ -34,16 +35,31 @@ public:
 template <class T>
 class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
-private:
+private:
+    static T* cudaAlloc(NbElts_t length) {
+        T* data;
+        CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&data), length * sizeof(T)));
+        return data;
+    }
+
+    static void cudaDelete(T* data) {
+        // Should not be called if data is nullptr, according to the standard
+        cudaFree(data);
+    }
+
+private:
     const Tensor &mTensor; // Impl needs to access Tensor information, but is not
                            // supposed to change it!
-    T* mData = nullptr;
+    /// Pointer to the data and its capacity
+    future_std::span<T> mData;
+    /// If this instance own the data, std::unique_ptr manages it
+    std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner;
     mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;

-public:
+public:
     static constexpr const char *Backend = "cuda";

-    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}
+    TensorImpl_cuda(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor), mDataOwner(nullptr, cudaDelete) {}

     bool operator==(const TensorImpl &otherImpl) const override final;
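As an aside, here is a minimal standalone sketch (not part of this commit) of the ownership scheme introduced above: a std::unique_ptr with a custom CUDA deleter owns the device allocation, while a separate pointer/capacity pair plays the role of the non-owning span view. The helper names deviceAlloc and deviceDelete are illustrative, not from the repository.

#include <cstddef>
#include <memory>
#include <cuda_runtime.h>

// Illustrative helpers mirroring the cudaAlloc/cudaDelete pair above.
static float* deviceAlloc(std::size_t length) {
    float* data = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&data), length * sizeof(float));
    return data;
}

static void deviceDelete(float* data) {
    // std::unique_ptr never invokes its deleter on a null pointer
    cudaFree(data);
}

int main() {
    // Owning handle: the device buffer is released automatically when `owner` goes out of scope.
    std::unique_ptr<float, decltype(&deviceDelete)> owner(deviceAlloc(256), deviceDelete);

    // Non-owning view: a pointer plus its capacity, the information a span carries.
    float* view = owner.get();
    std::size_t capacity = 256;

    return (view != nullptr && capacity == 256) ? 0 : 1;
}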
@@ -52,7 +68,7 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
     }

     // native interface
-    const T* data() const { return mData; }
+    const future_std::span<T>& data() const { return mData; }

     std::size_t scalarSize() const override { return sizeof(T); }
@@ -133,13 +149,13 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
     }

     void *rawPtr() override {
-        lazyInit(reinterpret_cast<void**>(&mData));
-        return mData;
+        lazyInit();
+        return mData.data();
     };

     const void *rawPtr() const override {
-        AIDGE_ASSERT(mData != nullptr, "accessing uninitialized const rawPtr");
-        return mData;
+        AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
+        return mData.data();
     };

     void* getRaw(std::size_t idx) {
@@ -180,21 +196,26 @@ class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
     }

     virtual ~TensorImpl_cuda() {
-        if (mData != nullptr)
-            cudaFree(mData);
         if (mCudnnTensor != nullptr)
             cudnnDestroyTensorDescriptor(mCudnnTensor);
     }

-    void setRawPtr(void* /*ptr*/) override final {
-        printf("Not implemented yet.");
+    void setRawPtr(void *ptr, NbElts_t length) override final {
+        AIDGE_ASSERT(length >= mTensor.size(), "trying to set raw pointer of insufficient capacity");
+        mData = future_std::span<T>(static_cast<T *>(ptr), length);
+        mDataOwner.reset();
     };

-private:
-    void lazyInit(void** data) {
-        if (*data == nullptr)
-            CHECK_CUDA_STATUS(cudaMalloc(data, mTensor.size() * sizeof(T)));
+private:
+    void lazyInit() {
+        AIDGE_INTERNAL_ASSERT(mTensor.dataType() == NativeType<T>::type);
+        if (mData.size() < mTensor.size()) {
+            // Need more data, a re-allocation will occur
+            AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "trying to enlarge non-owned data");
+            mDataOwner.reset(cudaAlloc(mTensor.size()));
+            mData = future_std::span<T>(mDataOwner.get(), mTensor.size());
+        }
     }
 };
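To make the reallocation policy above concrete, here is a rough sketch (hypothetical names, not code from this repository) of the behaviour encoded in lazyInit(): allocate only when the requested element count exceeds the current capacity, and only when this instance owns its buffer. As in the hunk above, the previous contents are not copied on reallocation.

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <cuda_runtime.h>

static void deviceFree(float* p) { cudaFree(p); }

struct DeviceBuffer {
    float* ptr = nullptr;            // non-owning view of the data
    std::size_t capacity = 0;        // number of elements currently allocated
    std::unique_ptr<float, decltype(&deviceFree)> owner{nullptr, deviceFree};

    // Grow-only reallocation, mirroring the capacity check in lazyInit().
    void ensureCapacity(std::size_t required) {
        if (capacity >= required)
            return;                  // enough room already, nothing to do
        if (ptr != nullptr && owner == nullptr)
            throw std::runtime_error("trying to enlarge non-owned data");
        float* data = nullptr;
        cudaMalloc(reinterpret_cast<void**>(&data), required * sizeof(float));
        owner.reset(data);           // frees the previous allocation, if any
        ptr = data;                  // note: previous contents are not copied
        capacity = required;
    }
};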
@@ -29,7 +29,7 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
     if (mTensor.size() != otherImplCuda.mTensor.size())
         return false;

-    thrust::device_ptr<T> thrustData(mData);
-    thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData);
+    thrust::device_ptr<T> thrustData(mData.data());
+    thrust::device_ptr<T> thrustOtherData(otherImplCuda.mData.data());
     return thrust::equal(thrustData, thrustData + mTensor.size(), thrustOtherData);
 }
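For reference, a small self-contained example (assumed usage, not code from this repository) of the comparison idiom used in this hunk: wrap the raw device pointers in thrust::device_ptr so that thrust::equal runs the element-wise comparison on the GPU.

#include <cstddef>
#include <thrust/device_ptr.h>
#include <thrust/equal.h>

// Compares two device buffers of `size` elements each.
bool deviceBuffersEqual(const float* a, const float* b, std::size_t size) {
    thrust::device_ptr<const float> pa(a);
    thrust::device_ptr<const float> pb(b);
    return thrust::equal(pa, pa + size, pb);
}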