From 1992ac679e87174bd431832e1125722b59f097d6 Mon Sep 17 00:00:00 2001 From: Charles Villard <charles.villard@cea.fr> Date: Thu, 10 Oct 2024 15:20:59 +0200 Subject: [PATCH] fix: TensorImpl: Enable if on return type instead of parameter --- include/aidge/backend/cuda/data/TensorImpl.hpp | 13 +++++++++---- src/data/TensorImpl.cu | 12 +++++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index 18a6a57..3a1499a 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -22,10 +22,15 @@ namespace Aidge { template <typename SRC_T, typename DST_T> void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/); -template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr> -void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size); -template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr> -void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size); + +template <typename SRC_T> +typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type +thrust_copy(const SRC_T *srcData, half_float::half *dstData, size_t size); + +template <typename DST_T> +typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type +thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size); + template <> void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size); diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu index 2084143..0af0fc1 100644 --- a/src/data/TensorImpl.cu +++ b/src/data/TensorImpl.cu @@ -36,8 +36,9 @@ cudaCopyToH_kernel(const SRC_T* srcData, } } -template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type*> -void Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size) +template <typename SRC_T> +typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type +Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size) { cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>> (srcData, reinterpret_cast<__half*>(dstData), size); @@ -58,8 +59,9 @@ cudaCopyFromH_kernel(const __half* srcData, } } -template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type*> -void Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size) +template <typename DST_T> +typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type +Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size) { cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>> (reinterpret_cast<const __half*>(srcData), dstData, size); @@ -230,4 +232,4 @@ template void Aidge::thrust_copy<>(uint8_t const*, int8_t*, size_t); template void Aidge::thrust_copy<>(uint8_t const*, uint64_t*, size_t); template void Aidge::thrust_copy<>(uint8_t const*, uint32_t*, size_t); template void Aidge::thrust_copy<>(uint8_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t); \ No newline at end of file +template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t); -- GitLab