diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index 176624b0df9f0a90eba5d8afe9496125f2aad898..88e67d0fb5987ee10ff8db7faa66ae4202aae9fc 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -14,7 +14,13 @@ namespace Aidge { template <typename SRC_T, typename DST_T> -void thrust_copy(SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/); +void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/); +template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr> +void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size); +template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr> +void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size); +template <> +void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size); /** * @brief Abstract class for the TensorImpl_cuda class template. diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu index 8ce0e2d6d3430ed40c32bba401e362a56e540ff5..ecacd4d678dd7d79462332fb28e238b063d8bdd1 100644 --- a/src/data/TensorImpl.cu +++ b/src/data/TensorImpl.cu @@ -15,13 +15,78 @@ #include <thrust/device_ptr.h> template <typename SRC_T, typename DST_T> -void Aidge::thrust_copy(SRC_T* srcData, DST_T* dstData, size_t size) +void Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size) { - thrust::device_ptr<SRC_T> thrustSrcPtr(srcData); + const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData); thrust::device_ptr<DST_T> thrustDstPtr(dstData); thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr); } +template <typename SRC_T> +__global__ void +cudaCopyToH_kernel(const SRC_T* srcData, + __half* dstData, + size_t size) +{ + const unsigned int index = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int stride = blockDim.x * gridDim.x; + + for (unsigned int i = index; i < size; i += stride) { + dstData[i] = __float2half(static_cast<float>(srcData[i])); + } +} + +template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr> +void Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size) +{ + cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>> + (srcData, reinterpret_cast<__half*>(dstData), size); + CHECK_CUDA_STATUS(cudaPeekAtLastError()); +} + +template <typename DST_T> +__global__ void +cudaCopyFromH_kernel(const __half* srcData, + DST_T* dstData, + size_t size) +{ + const unsigned int index = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int stride = blockDim.x * gridDim.x; + + for (unsigned int i = index; i < size; i += stride) { + dstData[i] = static_cast<DST_T>(__half2float(srcData[i])); + } +} + +template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr> +void Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size) +{ + cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>> + (reinterpret_cast<const __half*>(srcData), dstData, size); + CHECK_CUDA_STATUS(cudaPeekAtLastError()); +} + +__global__ void +cudaCopyHToH_kernel(const __half* srcData, + __half* dstData, + size_t size) +{ + const unsigned int index = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int stride = blockDim.x * gridDim.x; + + for (unsigned int i = index; i < size; i += stride) { + dstData[i] = srcData[i]; + } +} + +template <> +void Aidge::thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size) +{ + cudaCopyHToH_kernel<<<(size + 255) / 256, 256>>> + (reinterpret_cast<const __half*>(srcData), reinterpret_cast<__half*>(dstData), size); + CHECK_CUDA_STATUS(cudaPeekAtLastError()); +} + template <class T> bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const { const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl);