Skip to content
Snippets Groups Projects
Commit 1c8e1637 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added thrust_copy specializations for half float

parent 10f6c6be
No related branches found
No related tags found
1 merge request: !4 Add Convert operator (a.k.a. Transmitter)
Pipeline #35594 failed
...@@ -14,7 +14,13 @@ ...@@ -14,7 +14,13 @@
namespace Aidge { namespace Aidge {
template <typename SRC_T, typename DST_T> template <typename SRC_T, typename DST_T>
void thrust_copy(SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/); void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
// Overload selected when the destination is half_float::half and the source
// is any other type (the enable_if removes it when SRC_T is half_float::half).
template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size);
// Overload selected when the source is half_float::half and the destination
// is any other type (the enable_if removes it when DST_T is half_float::half).
template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
// Explicit specialization for the half_float::half -> half_float::half case,
// which both enable_if overloads above exclude.
template <>
void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size);
/** /**
* @brief Abstract class for the TensorImpl_cuda class template. * @brief Abstract class for the TensorImpl_cuda class template.
......
...@@ -15,13 +15,78 @@ ...@@ -15,13 +15,78 @@
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
// Generic thrust_copy: wraps the raw srcData/dstData pointers in
// thrust::device_ptr and copies `size` elements from source to destination
// with thrust::copy.
// NOTE(review): each line below carries two revisions of the same statement
// side by side (diff rendering); the right-hand one takes the source as const.
template <typename SRC_T, typename DST_T> template <typename SRC_T, typename DST_T>
void Aidge::thrust_copy(SRC_T* srcData, DST_T* dstData, size_t size) void Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size)
{ {
thrust::device_ptr<SRC_T> thrustSrcPtr(srcData); const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData);
thrust::device_ptr<DST_T> thrustDstPtr(dstData); thrust::device_ptr<DST_T> thrustDstPtr(dstData);
thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr); thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr);
} }
/**
 * @brief Kernel converting an array of SRC_T into an array of __half.
 *
 * Every element is cast to float and then converted with __float2half().
 * Each thread starts at its global 1D index and advances by the total number
 * of launched threads, so the whole [0, size) range is covered whatever the
 * 1D launch configuration is.
 *
 * @tparam SRC_T   Source element type; must be castable to float.
 * @param srcData  Input array, read at indices [0, size).
 * @param dstData  Output array, written at indices [0, size).
 * @param size     Number of elements to convert.
 */
template <typename SRC_T>
__global__ void
cudaCopyToH_kernel(const SRC_T* srcData,
                   __half* dstData,
                   size_t size)
{
    // Indices are computed in size_t: a 32-bit counter compared against a
    // size_t bound wraps around for large sizes (the loop then never ends),
    // and blockIdx.x * blockDim.x would be evaluated in 32 bits.
    const size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

    for (size_t i = index; i < size; i += stride) {
        dstData[i] = __float2half(static_cast<float>(srcData[i]));
    }
}
/**
 * @brief thrust_copy overload converting any non-half type to half_float::half.
 *
 * Launches cudaCopyToH_kernel with 256 threads per block on the default
 * stream and checks the launch status with cudaPeekAtLastError(). No
 * synchronization is performed here.
 *
 * @tparam SRC_T   Source element type (anything but half_float::half).
 * @param srcData  Source array passed to the kernel.
 * @param dstData  Destination array; reinterpreted as __half*.
 *                 NOTE(review): assumes half_float::half and __half share the
 *                 same 16-bit binary layout - confirm.
 * @param size     Number of elements to convert; 0 is a no-op.
 */
template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
void Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size)
{
    // A grid of 0 blocks is an invalid launch configuration, which would be
    // reported as an error below: an empty copy must simply do nothing.
    if (size == 0) {
        return;
    }

    constexpr unsigned int threadsPerBlock = 256;
    // Largest grid x-dimension accepted by the runtime (2^31 - 1). The block
    // count is clamped instead of being silently narrowed from size_t; the
    // kernel's strided loop processes whatever one pass does not cover.
    constexpr size_t maxBlocks = 2147483647;
    // size > 0 here, so this ceil-division cannot overflow.
    const size_t wantedBlocks = (size - 1) / threadsPerBlock + 1;
    const unsigned int blocks = static_cast<unsigned int>(
        wantedBlocks < maxBlocks ? wantedBlocks : maxBlocks);

    cudaCopyToH_kernel<SRC_T><<<blocks, threadsPerBlock>>>
        (srcData, reinterpret_cast<__half*>(dstData), size);
    CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
/**
 * @brief Kernel converting an array of __half into an array of DST_T.
 *
 * Every element is converted to float with __half2float() and then cast to
 * DST_T. Each thread starts at its global 1D index and advances by the total
 * number of launched threads, so the whole [0, size) range is covered
 * whatever the 1D launch configuration is.
 *
 * @tparam DST_T   Destination element type; must be castable from float.
 * @param srcData  Input array, read at indices [0, size).
 * @param dstData  Output array, written at indices [0, size).
 * @param size     Number of elements to convert.
 */
template <typename DST_T>
__global__ void
cudaCopyFromH_kernel(const __half* srcData,
                     DST_T* dstData,
                     size_t size)
{
    // Indices are computed in size_t: a 32-bit counter compared against a
    // size_t bound wraps around for large sizes (the loop then never ends),
    // and blockIdx.x * blockDim.x would be evaluated in 32 bits.
    const size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

    for (size_t i = index; i < size; i += stride) {
        dstData[i] = static_cast<DST_T>(__half2float(srcData[i]));
    }
}
/**
 * @brief thrust_copy overload converting half_float::half to any non-half type.
 *
 * Launches cudaCopyFromH_kernel with 256 threads per block on the default
 * stream and checks the launch status with cudaPeekAtLastError(). No
 * synchronization is performed here.
 *
 * @tparam DST_T   Destination element type (anything but half_float::half).
 * @param srcData  Source array; reinterpreted as const __half*.
 *                 NOTE(review): assumes half_float::half and __half share the
 *                 same 16-bit binary layout - confirm.
 * @param dstData  Destination array passed to the kernel.
 * @param size     Number of elements to convert; 0 is a no-op.
 */
template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
void Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size)
{
    // A grid of 0 blocks is an invalid launch configuration, which would be
    // reported as an error below: an empty copy must simply do nothing.
    if (size == 0) {
        return;
    }

    constexpr unsigned int threadsPerBlock = 256;
    // Largest grid x-dimension accepted by the runtime (2^31 - 1). The block
    // count is clamped instead of being silently narrowed from size_t; the
    // kernel's strided loop processes whatever one pass does not cover.
    constexpr size_t maxBlocks = 2147483647;
    // size > 0 here, so this ceil-division cannot overflow.
    const size_t wantedBlocks = (size - 1) / threadsPerBlock + 1;
    const unsigned int blocks = static_cast<unsigned int>(
        wantedBlocks < maxBlocks ? wantedBlocks : maxBlocks);

    cudaCopyFromH_kernel<DST_T><<<blocks, threadsPerBlock>>>
        (reinterpret_cast<const __half*>(srcData), dstData, size);
    CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
/**
 * @brief Kernel copying an array of __half into another array of __half.
 *
 * Elements are assigned one by one without any conversion. Each thread starts
 * at its global 1D index and advances by the total number of launched
 * threads, so the whole [0, size) range is covered whatever the 1D launch
 * configuration is.
 *
 * @param srcData  Input array, read at indices [0, size).
 * @param dstData  Output array, written at indices [0, size).
 * @param size     Number of elements to copy.
 */
__global__ void
cudaCopyHToH_kernel(const __half* srcData,
                    __half* dstData,
                    size_t size)
{
    // Indices are computed in size_t: a 32-bit counter compared against a
    // size_t bound wraps around for large sizes (the loop then never ends),
    // and blockIdx.x * blockDim.x would be evaluated in 32 bits.
    const size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

    for (size_t i = index; i < size; i += stride) {
        dstData[i] = srcData[i];
    }
}
/**
 * @brief thrust_copy specialization for half_float::half to half_float::half.
 *
 * Launches cudaCopyHToH_kernel with 256 threads per block on the default
 * stream and checks the launch status with cudaPeekAtLastError(). No
 * synchronization is performed here.
 *
 * @param srcData  Source array; reinterpreted as const __half*.
 * @param dstData  Destination array; reinterpreted as __half*.
 *                 NOTE(review): assumes half_float::half and __half share the
 *                 same 16-bit binary layout - confirm.
 * @param size     Number of elements to copy; 0 is a no-op.
 */
template <>
void Aidge::thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size)
{
    // A grid of 0 blocks is an invalid launch configuration, which would be
    // reported as an error below: an empty copy must simply do nothing.
    if (size == 0) {
        return;
    }

    constexpr unsigned int threadsPerBlock = 256;
    // Largest grid x-dimension accepted by the runtime (2^31 - 1). The block
    // count is clamped instead of being silently narrowed from size_t; the
    // kernel's strided loop processes whatever one pass does not cover.
    constexpr size_t maxBlocks = 2147483647;
    // size > 0 here, so this ceil-division cannot overflow.
    const size_t wantedBlocks = (size - 1) / threadsPerBlock + 1;
    const unsigned int blocks = static_cast<unsigned int>(
        wantedBlocks < maxBlocks ? wantedBlocks : maxBlocks);

    cudaCopyHToH_kernel<<<blocks, threadsPerBlock>>>
        (reinterpret_cast<const __half*>(srcData), reinterpret_cast<__half*>(dstData), size);
    CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
template <class T> template <class T>
bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const { bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl); const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment