From 1992ac679e87174bd431832e1125722b59f097d6 Mon Sep 17 00:00:00 2001
From: Charles Villard <charles.villard@cea.fr>
Date: Thu, 10 Oct 2024 15:20:59 +0200
Subject: [PATCH] fix: TensorImpl: Enable if on return type instead of
 parameter

---
 include/aidge/backend/cuda/data/TensorImpl.hpp | 13 +++++++++----
 src/data/TensorImpl.cu                         | 12 +++++++-----
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 18a6a57..3a1499a 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -22,10 +22,15 @@ namespace Aidge {
 
 template <typename SRC_T, typename DST_T>
 void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
-template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
-void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size);
-template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
-void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
+
+template <typename SRC_T>
+typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type
+thrust_copy(const SRC_T *srcData, half_float::half *dstData, size_t size);
+
+template <typename DST_T>
+typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type
+thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
+
 template <>
 void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size);
 
diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu
index 2084143..0af0fc1 100644
--- a/src/data/TensorImpl.cu
+++ b/src/data/TensorImpl.cu
@@ -36,8 +36,9 @@ cudaCopyToH_kernel(const SRC_T* srcData,
     }
 }
 
-template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type*>
-void Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size)
+template <typename SRC_T>
+typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type
+Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size)
 {
     cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>>
         (srcData, reinterpret_cast<__half*>(dstData), size);
@@ -58,8 +59,9 @@ cudaCopyFromH_kernel(const __half* srcData,
     }
 }
 
-template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type*>
-void Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size)
+template <typename DST_T>
+typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type
+Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size)
 {
     cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>>
         (reinterpret_cast<const __half*>(srcData), dstData, size);
@@ -230,4 +232,4 @@ template void Aidge::thrust_copy<>(uint8_t const*, int8_t*, size_t);
 template void Aidge::thrust_copy<>(uint8_t const*, uint64_t*, size_t);
 template void Aidge::thrust_copy<>(uint8_t const*, uint32_t*, size_t);
 template void Aidge::thrust_copy<>(uint8_t const*, uint16_t*, size_t);
-template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t);
\ No newline at end of file
+template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t);
-- 
GitLab