diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index f8b1d60d75febd9ad569de2c2b32db8e7ef88541..a05b7710d7ff489d3f04e3a475a10ca42bd7c8c6 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -21,20 +21,23 @@ namespace Aidge { + template <typename SRC_T, typename DST_T> -typename std::enable_if<!std::is_same<half_float::half, DST_T>::value && !std::is_same<half_float::half, SRC_T>::value>::type -thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/); + std::enable_if_t< + !(std::is_same<half_float::half, DST_T>::value + || std::is_same<half_float::half, SRC_T>::value)> +thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size); -template <typename SRC_T, typename DST_T = half_float::half> -typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type -thrust_copy(const SRC_T *srcData, half_float::half *dstData, size_t size); +void +thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size); -template <typename SRC_T = half_float::half, typename DST_T> -typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type +template <typename DST_T> +std::enable_if_t<!std::is_same<half_float::half, DST_T>::value> thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size); -template <typename SRC_T = half_float::half, typename DST_T = half_float::half> -void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size); +template <typename SRC_T> +std::enable_if_t<!std::is_same<half_float::half, SRC_T>::value> +thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size); /** * @brief Abstract class for the TensorImpl_cuda class template. diff --git a/src/data/TensorImpl.cu b/src/data/TensorImpl.cu index caa931e8dece35dbd66cee4d46b0ede1bb153714..f24e91ec50f389cf8d7bf4f59533a66a338dcc3c 100644 --- a/src/data/TensorImpl.cu +++ b/src/data/TensorImpl.cu @@ -14,17 +14,6 @@ #include <thrust/equal.h> #include <thrust/device_ptr.h> -template <typename SRC_T, typename DST_T> -typename std::enable_if<!std::is_same<half_float::half, DST_T>::value && - !std::is_same<half_float::half, SRC_T>::value> - ::type -Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size) -{ - const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData); - thrust::device_ptr<DST_T> thrustDstPtr(dstData); - thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr); -} - template <typename SRC_T> __global__ void cudaCopyToH_kernel(const SRC_T* srcData, @@ -39,15 +28,6 @@ cudaCopyToH_kernel(const SRC_T* srcData, } } -template <typename SRC_T, typename DST_T> -typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type -Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size) -{ - cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>> - (srcData, reinterpret_cast<__half*>(dstData), size); - CHECK_CUDA_STATUS(cudaPeekAtLastError()); -} - template <typename DST_T> __global__ void cudaCopyFromH_kernel(const __half* srcData, @@ -62,15 +42,6 @@ cudaCopyFromH_kernel(const __half* srcData, } } -template <typename SRC_T, typename DST_T> -typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type -Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size) -{ - cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>> - (reinterpret_cast<const __half*>(srcData), dstData, size); - CHECK_CUDA_STATUS(cudaPeekAtLastError()); -} - __global__ void cudaCopyHToH_kernel(const __half* srcData, __half* dstData, @@ -84,14 +55,6 @@ cudaCopyHToH_kernel(const __half* srcData, } } -template <typename SRC_T, typename DST_T> -void Aidge::thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size) -{ - cudaCopyHToH_kernel<<<(size + 255) / 256, 256>>> - (reinterpret_cast<const __half*>(srcData), reinterpret_cast<__half*>(dstData), size); - CHECK_CUDA_STATUS(cudaPeekAtLastError()); -} - template <class T> bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const { const auto& otherImplCuda = static_cast<const TensorImpl_cuda<T>&>(otherImpl); @@ -104,135 +67,187 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const { return thrust::equal(thrustData, thrustData + mNbElts, thrustOtherData); } +//// Thrust copy + +template <typename SRC_T, typename DST_T> +std::enable_if_t< + !(std::is_same<half_float::half, DST_T>::value + || std::is_same<half_float::half, SRC_T>::value)> +Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size) +{ + const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData); + thrust::device_ptr<DST_T> thrustDstPtr(dstData); + thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr); +} + +void +Aidge::thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size) +{ + cudaCopyHToH_kernel<<<(size + 255) / 256, 256>>> + (reinterpret_cast<const __half*>(srcData), reinterpret_cast<__half*>(dstData), size); + CHECK_CUDA_STATUS(cudaPeekAtLastError()); +} + +template <typename DST_T> +std::enable_if_t<!std::is_same<half_float::half, DST_T>::value> +Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size) +{ + cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>> + (reinterpret_cast<const __half*>(srcData), dstData, size); + CHECK_CUDA_STATUS(cudaPeekAtLastError()); +} + +template <typename SRC_T> +std::enable_if_t<!std::is_same<half_float::half, SRC_T>::value> +Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size) +{ + cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>> + (srcData, reinterpret_cast<__half*>(dstData), size); + CHECK_CUDA_STATUS(cudaPeekAtLastError()); +} + + +/// End Thrust copy + // double -template void Aidge::thrust_copy<>(double const*, double*, size_t); -template void Aidge::thrust_copy<>(double const*, float*, size_t); -template void Aidge::thrust_copy<>(double const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(double const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(double const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(double const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(double const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(double const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(double const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(double const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(double const*, uint8_t*, size_t); +template<> void Aidge::thrust_copy(double const*, double*, size_t); +template<> void Aidge::thrust_copy(double const*, float*, size_t); +template<> void Aidge::thrust_copy(double const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(double const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(double const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(double const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(double const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(double const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(double const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(double const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(double const*, uint8_t*, size_t); + // float -template void Aidge::thrust_copy<>(float const*, double*, size_t); -template void Aidge::thrust_copy<>(float const*, float*, size_t); -template void Aidge::thrust_copy<>(float const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(float const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(float const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(float const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(float const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(float const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(float const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(float const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(float const*, uint8_t*, size_t); +template<> void Aidge::thrust_copy(float const*, double*, size_t); +template<> void Aidge::thrust_copy(float const*, float*, size_t); +template<> void Aidge::thrust_copy(float const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(float const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(float const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(float const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(float const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(float const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(float const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(float const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(float const*, uint8_t*, size_t); + // half_float::half -template void Aidge::thrust_copy<>(half_float::half const*, double*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, float*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(half_float::half const*, uint8_t*, size_t); -// int64_t -template void Aidge::thrust_copy<>(int64_t const*, double*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, float*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(int64_t const*, uint8_t*, size_t); -// int32_t -template void Aidge::thrust_copy<>(int32_t const*, double*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, float*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(int32_t const*, uint8_t*, size_t); -// int16_t -template void Aidge::thrust_copy<>(int16_t const*, double*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, float*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(int16_t const*, uint8_t*, size_t); -// int8_t -template void Aidge::thrust_copy<>(int8_t const*, double*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, float*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(int8_t const*, uint8_t*, size_t); -// uint64_t -template void Aidge::thrust_copy<>(uint64_t const*, double*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, float*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(uint64_t const*, uint8_t*, size_t); -// uint32_t -template void Aidge::thrust_copy<>(uint32_t const*, double*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, float*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(uint32_t const*, uint8_t*, size_t); -// uint16_t -template void Aidge::thrust_copy<>(uint16_t const*, double*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, float*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(uint16_t const*, uint8_t*, size_t); -// uint8_t -template void Aidge::thrust_copy<>(uint8_t const*, double*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, float*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, half_float::half*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, int64_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, int32_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, int16_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, int8_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, uint64_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, uint32_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, uint16_t*, size_t); -template void Aidge::thrust_copy<>(uint8_t const*, uint8_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, double*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, float*, size_t); +//template<> void Aidge::thrust_copy(const half_float::half*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(half_float::half const*, uint8_t*, size_t); + + // int64_t +template<> void Aidge::thrust_copy(int64_t const*, double*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, float*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(int64_t const*, uint8_t*, size_t); + + // int32_t +template<> void Aidge::thrust_copy(int32_t const*, double*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, float*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(int32_t const*, uint8_t*, size_t); + + // int16_t +template<> void Aidge::thrust_copy(int16_t const*, double*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, float*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(int16_t const*, uint8_t*, size_t); + + // int8_t +template<> void Aidge::thrust_copy(int8_t const*, double*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, float*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(int8_t const*, uint8_t*, size_t); + + // uint64_t +template<> void Aidge::thrust_copy(uint64_t const*, double*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, float*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(uint64_t const*, uint8_t*, size_t); + + // uint32_t +template<> void Aidge::thrust_copy(uint32_t const*, double*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, float*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(uint32_t const*, uint8_t*, size_t); + + // uint16_t +template<> void Aidge::thrust_copy(uint16_t const*, double*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, float*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(uint16_t const*, uint8_t*, size_t); + + // uint8_t +template<> void Aidge::thrust_copy(uint8_t const*, double*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, float*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, half_float::half*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, int64_t*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, int32_t*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, int16_t*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, int8_t*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, uint64_t*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, uint32_t*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, uint16_t*, size_t); +template<> void Aidge::thrust_copy(uint8_t const*, uint8_t*, size_t);