diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index 18a6a574f2b5970ab2b828cf69fe66165f52bbbb..3a1499a20915c356bc136367e0085a5824e7baba 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -22,10 +22,15 @@ namespace Aidge { template <typename SRC_T, typename DST_T> void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/); -template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr> -void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size); -template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr> -void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size); + +template <typename SRC_T> +typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type +thrust_copy(const SRC_T *srcData, half_float::half *dstData, size_t size); + +template <typename DST_T> +typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type +thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size); + template <> void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size); diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu index 208414348f1f346c765cf8b97e919e3053513df0..0af0fc190f1d9c8ca09dec4dce28501886a1fe27 100644 --- a/src/data/TensorImpl.cu +++ b/src/data/TensorImpl.cu @@ -36,8 +36,9 @@ cudaCopyToH_kernel(const SRC_T* srcData, } } -template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type*> -void Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size) +template <typename SRC_T> +typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type +Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size) { cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>> (srcData, reinterpret_cast<__half*>(dstData), size); @@ -58,8 +59,9 @@ cudaCopyFromH_kernel(const __half* srcData, } } -template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type*> -void Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size) +template <typename DST_T> +typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type +Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size) { cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>> (reinterpret_cast<const __half*>(srcData), dstData, size); @@ -230,4 +232,4 @@ template void Aidge::thrust_copy<>(uint8_t const*, int8_t*, size_t); template void Aidge::thrust_copy<>(uint8_t const*, uint64_t*, size_t); template void Aidge::thrust_copy<>(uint8_t const*, uint32_t*, size_t); template void Aidge::thrust_copy<>(uint8_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t); \ No newline at end of file +template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t);