#ifndef AIDGE_BACKEND_CUDA_DATA_TENSORIMPL_H_
#define AIDGE_BACKEND_CUDA_DATA_TENSORIMPL_H_
#include <algorithm>   // std::find_if
#include <cstddef>     // std::size_t
#include <cstdint>     // std::int8_t, std::int16_t, ...
#include <memory>
#include <string>
#include <type_traits> // std::enable_if, std::is_same
#include <utility>     // std::pair
#include <vector>

#include <cuda.h>
#include <cudnn.h>
#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/future_std/span.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp"
namespace Aidge {
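// Element-wise, type-converting copies between device buffers. These helpers
// are only declared here; in the upstream layout they are implemented with
// Thrust in a separate .cu translation unit. The half_float::half overloads
// are split out because that type requires explicit conversion handling.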
template <typename SRC_T, typename DST_T>
void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size);
template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
template <>
void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size);
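// Usage sketch (assuming `src` and `dst` are device pointers holding `size`
// elements, e.g. allocated with cudaMalloc):
//   const float* src = /* device buffer */;
//   double* dst = /* device buffer */;
//   thrust_copy(src, dst, size); // element-wise float -> double cast on the GPU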
/**
 * @brief Abstract base class for the TensorImpl_cuda class template.
 * @details It gives access to methods that are specific to the CUDA
 * implementation (and therefore absent from the generic TensorImpl class)
 * but that do not depend on the element data type.
 */
class TensorImpl_cuda_ {
protected:
mutable cudnnTensorDescriptor_t mCudnnTensor = nullptr;
public:
    /**
     * @brief Return the cuDNN tensor descriptor of the tensor.
     * @details The descriptor is created lazily on first access (hence the
     * mutable mCudnnTensor member) and updated when the tensor shape changes.
     * @param tensor Tensor whose dimensions and strides describe the descriptor.
     * @return cudnnTensorDescriptor_t cuDNN tensor descriptor.
     */
virtual const cudnnTensorDescriptor_t& getCudnnTensorDesc(const Tensor& tensor) const = 0;
virtual ~TensorImpl_cuda_() {
if (mCudnnTensor != nullptr)
cudnnDestroyTensorDescriptor(mCudnnTensor);
}
};
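// Typical use (sketch): operator implementations that only need the cuDNN
// descriptor can go through this type-erased base without knowing the element
// type T, e.g. (assuming the upstream Tensor::getImpl() accessor):
//   const auto& desc = std::static_pointer_cast<TensorImpl_cuda_>(
//       tensor.getImpl())->getCudnnTensorDesc(tensor);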
template <class T>
class TensorImpl_cuda : public TensorImpl, public TensorImpl_cuda_ {
private:
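    // Allocates a device buffer large enough for `length` elements of T;
    // CHECK_CUDA_STATUS reports failure if the allocation does not succeed.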
static T* cudaAlloc(NbElts_t length) {
T* data;
CHECK_CUDA_STATUS(cudaMalloc(reinterpret_cast<void**>(&data), length * sizeof(T)));
return data;
}
    static void cudaDelete(T* data) {
        // The std::unique_ptr deleter is never invoked with a null pointer
        // (per the standard), so no null check is needed before cudaFree
        cudaFree(data);
    }
private:
    /// Non-owning view on the device data
    future_std::span<T> mData;
    /// If this instance owns the data, this std::unique_ptr manages it
    std::unique_ptr<T, decltype(&cudaDelete)> mDataOwner;
public:
static const std::string Backend;
TensorImpl_cuda(DeviceIdx_t device, std::vector<DimSize_t> dims) : TensorImpl(Backend, device, dims), mDataOwner(nullptr, cudaDelete) {}
bool operator==(const TensorImpl &otherImpl) const override final;
static std::shared_ptr<TensorImpl_cuda> create(DeviceIdx_t device, std::vector<DimSize_t> dims) {
return std::make_shared<TensorImpl_cuda<T>>(device, dims);
}
// native interface
const future_std::span<T>& data() const { return mData; }
inline std::size_t capacity() const noexcept override { return mData.size(); }
std::size_t scalarSize() const noexcept override { return sizeof(T); }
    void zeros() override final {
        // cudaMemset sets individual bytes, so only the value 0 is meaningful here
        CHECK_CUDA_STATUS(cudaMemset(rawPtr(), 0, mNbElts * sizeof(T)));
    }
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
AIDGE_ASSERT(offset + length <= mNbElts, "TensorImpl_cuda<{}>::copy(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mNbElts);
const T* srcT = static_cast<const T *>(src);
T* dstT = static_cast<T *>(rawPtr(offset));
AIDGE_ASSERT(dstT < srcT || dstT >= srcT + length, "TensorImpl_cuda<{}>::copy(): overlapping copy is not supported", typeid(T).name());
CHECK_CUDA_STATUS(cudaMemcpy(dstT, srcT, length * sizeof(T), cudaMemcpyDeviceToDevice));
}
void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override {
if (length == 0) {
return;
}
AIDGE_ASSERT(offset + length <= mNbElts, "TensorImpl_cuda<{}>::copyCast(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mNbElts);
switch (srcDt) {
case DataType::Float64:
thrust_copy(static_cast<const double*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::Float32:
thrust_copy(static_cast<const float*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::Float16:
thrust_copy(static_cast<const half_float::half*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::Int64:
thrust_copy(static_cast<const int64_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::UInt64:
thrust_copy(static_cast<const uint64_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::Int32:
thrust_copy(static_cast<const int32_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::UInt32:
thrust_copy(static_cast<const uint32_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::Int16:
thrust_copy(static_cast<const int16_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::UInt16:
thrust_copy(static_cast<const uint16_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::Int8:
thrust_copy(static_cast<const int8_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
case DataType::UInt8:
thrust_copy(static_cast<const uint8_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "TensorImpl_cuda<{}>::copyCast(): unsupported data type {}.", typeid(T).name(), srcDt);
break;
}
}
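    // Note: as with copy(), `src` is expected to be a device pointer; the
    // element-wise conversion itself runs on the GPU through thrust_copy.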
    void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) override {
        AIDGE_ASSERT(offset + length <= mNbElts, "TensorImpl_cuda<{}>::copyFromDevice(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mNbElts);
        // Note: the source device identifier is currently ignored; a plain
        // device-to-device cudaMemcpy is performed
        CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(offset), src, length * sizeof(T), cudaMemcpyDeviceToDevice));
    }
void copyFromHost(const void *src, NbElts_t length, NbElts_t offset = 0) override {
AIDGE_ASSERT(offset + length <= mNbElts, "TensorImpl_cuda<{}>::copyFromHost(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mNbElts);
CHECK_CUDA_STATUS(cudaMemcpy(rawPtr(offset), src, length * sizeof(T), cudaMemcpyHostToDevice));
}
void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const override {
AIDGE_ASSERT(offset + length <= mNbElts, "TensorImpl_cuda<{}>::copyToHost(): copy offset ({}) + length ({}) is above capacity ({})", typeid(T).name(), offset, length, mNbElts);
CHECK_CUDA_STATUS(cudaMemcpy(dst, rawPtr(offset), length * sizeof(T), cudaMemcpyDeviceToHost));
}
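    // Round-trip sketch (hypothetical host buffer, assuming the implementation
    // holds `n` elements):
    //   std::vector<float> host(n);
    //   impl.copyFromHost(host.data(), n); // upload
    //   impl.copyToHost(host.data(), n);   // download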
    void *rawPtr(NbElts_t offset = 0) override {
        lazyInit();
        return (mData.data() + offset);
    }
    const void *rawPtr(NbElts_t offset = 0) const override {
        AIDGE_ASSERT(mData.size() >= mNbElts, "TensorImpl_cuda<{}>::rawPtr(): accessing uninitialized const rawPtr", typeid(T).name());
        return (mData.data() + offset);
    }
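    // Lazily creates the cuDNN descriptor on first use, then keeps it in sync
    // with the tensor: on later calls the cached shape is compared against the
    // tensor's current dimensions and the descriptor is rebuilt if they differ.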
const cudnnTensorDescriptor_t& getCudnnTensorDesc(const Tensor& tensor) const override {
if (mCudnnTensor == nullptr) {
CHECK_CUDNN_STATUS(cudnnCreateTensorDescriptor(&mCudnnTensor));
if (tensor.size() > 0) {
                /**
                 ** cuDNN tensors are restricted to having at least 4 dimensions:
                 ** when working with lower-dimensional data, the unused dimensions
                 ** are set to 1. Refer to the cudnnSetTensorNdDescriptor
                 ** documentation:
                 ** https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html
                 **/
std::vector<int> dims(tensor.dims().cbegin(), tensor.dims().cend());
std::vector<int> strides(tensor.strides().cbegin(), tensor.strides().cend());
if (dims.size() < 4) {
dims.resize(4, 1);
strides.resize(4, 1);
}
                CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor,
                                                              CudaContext::data_type<T>::value,
                                                              static_cast<int>(dims.size()),
                                                              dims.data(),
                                                              strides.data()));
}
}
else {
            // Check whether the shape of the tensor has changed
            cudnnDataType_t currentDataType;
            int currentNbDims;
            // The number of dimensions of the current descriptor is unknown,
            // so query with CUDNN_DIM_MAX and strip the trailing zeros afterwards
            std::vector<int> currentDims(CUDNN_DIM_MAX);
            std::vector<int> currentStrides(CUDNN_DIM_MAX);
            CHECK_CUDNN_STATUS(cudnnGetTensorNdDescriptor(mCudnnTensor, CUDNN_DIM_MAX, &currentDataType, &currentNbDims, currentDims.data(), currentStrides.data()));
            // Remove the trailing zeros (unused entries were value-initialized to 0)
            currentDims.erase(std::find_if(currentDims.rbegin(), currentDims.rend(), [](int x) { return x != 0; }).base(),
                              currentDims.end());
std::vector<int> dims(tensor.dims().cbegin(), tensor.dims().cend());
if (dims.size() < 4) {
dims.resize(4, 1);
}
            // Update the descriptor only if the shape has changed
            if (dims != currentDims) {
                std::vector<int> strides(tensor.strides().cbegin(), tensor.strides().cend());
                if (strides.size() < 4) {
                    strides.resize(4, 1);
                }
                CHECK_CUDNN_STATUS(cudnnSetTensorNdDescriptor(mCudnnTensor,
                                                              CudaContext::data_type<T>::value,
                                                              static_cast<int>(dims.size()),
                                                              dims.data(),
                                                              strides.data()));
            }
}
return mCudnnTensor;
}
    void setRawPtr(void *ptr, NbElts_t length) override final {
        AIDGE_ASSERT(length >= mNbElts, "TensorImpl_cuda<{}>::setRawPtr(): trying to set raw pointer (length: {}) of insufficient capacity (required: {})", typeid(T).name(), length, mNbElts);
        mData = future_std::span<T>(static_cast<T *>(ptr), length);
        mDataOwner.reset();
    }
virtual ~TensorImpl_cuda() = default;
private:
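    // Allocates (or re-allocates) owned device memory whenever the current
    // buffer is smaller than the number of elements to store. Enlarging a
    // buffer set through setRawPtr() is rejected, since it is not owned.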
void lazyInit() {
if (mData.size() < mNbElts) {
// Need more data, a re-allocation will occur
AIDGE_ASSERT(mData.empty() || mDataOwner != nullptr, "TensorImpl_cuda<{}>: trying to enlarge non-owned data", typeid(T).name());
mDataOwner.reset(cudaAlloc(mNbElts));
mData = future_std::span<T>(mDataOwner.get(), mNbElts);
}
}
};
template <typename T>
const std::string TensorImpl_cuda<T>::Backend = "cuda";
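// The registrars below expose this implementation to the core library under
// the ("cuda", DataType) key. Usage sketch (assuming the upstream Tensor API,
// where setBackend() resolves the implementation through the registrar):
//   Tensor t;
//   t.resize({2, 3});
//   t.setDataType(DataType::Float32);
//   t.setBackend("cuda"); // dispatches to TensorImpl_cuda<float>::create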
namespace {
static Registrar<Tensor> registrarTensorImpl_cuda_Float64(
{"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create);
static Registrar<Tensor> registrarTensorImpl_cuda_Float32(
{"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create);
static Registrar<Tensor> registrarTensorImpl_cuda_Float16(
{"cuda", DataType::Float16}, Aidge::TensorImpl_cuda<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cuda_Int32(
{"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int32_t>::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_BACKEND_CUDA_DATA_TENSORIMPL_H_ */