Skip to content
Snippets Groups Projects
Commit 1562648a authored by Charles Villard's avatar Charles Villard Committed by Maxence Naud
Browse files

edit: TensorImpl: gave up original approach to use class partial

specialization because it is better handled by compilers
parent 29181ce4
No related branches found
No related tags found
2 merge requests!75Update 0.5.1 -> 0.6.0,!40Fix template compilation and warnings
Pipeline #70803 passed
......@@ -23,21 +23,7 @@ namespace Aidge {
template <typename SRC_T, typename DST_T>
std::enable_if_t<
!(std::is_same<half_float::half, DST_T>::value
|| std::is_same<half_float::half, SRC_T>::value)>
thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size);
void
thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size);
template <typename DST_T>
std::enable_if_t<!std::is_same<half_float::half, DST_T>::value>
thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
template <typename SRC_T>
std::enable_if_t<!std::is_same<half_float::half, SRC_T>::value>
thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size);
void thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size);
/**
* @brief Abstract class for the TensorImpl_cuda class template.
......
......@@ -69,185 +69,192 @@ bool Aidge::TensorImpl_cuda<T>::operator==(const TensorImpl &otherImpl) const {
//// Thrust copy
// General template for thrust_copy (this will handle types other than half_float::half)
template <typename SRC_T, typename DST_T>
std::enable_if_t<
!(std::is_same<half_float::half, DST_T>::value
|| std::is_same<half_float::half, SRC_T>::value)>
Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size)
{
const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData);
thrust::device_ptr<DST_T> thrustDstPtr(dstData);
thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr);
}
struct ThrustCopy {
static void copy(const SRC_T* srcData, DST_T* dstData, size_t size) {
const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData);
thrust::device_ptr<DST_T> thrustDstPtr(dstData);
thrust::copy(thrustSrcPtr, thrustSrcPtr + size, thrustDstPtr);
}
};
void
Aidge::thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size)
{
cudaCopyHToH_kernel<<<(size + 255) / 256, 256>>>
(reinterpret_cast<const __half*>(srcData), reinterpret_cast<__half*>(dstData), size);
CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
// Specialization for half_float::half, half_float::half
template <>
struct ThrustCopy<half_float::half, half_float::half> {
static void copy(const half_float::half* srcData, half_float::half* dstData, size_t size) {
cudaCopyHToH_kernel<<<(size + 255) / 256, 256>>>
(reinterpret_cast<const __half*>(srcData), reinterpret_cast<__half*>(dstData), size);
CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
};
// Specialization for half_float::half, DST_T (where DST_T is not half_float::half)
template <typename DST_T>
std::enable_if_t<!std::is_same<half_float::half, DST_T>::value>
Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size)
{
cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>>
(reinterpret_cast<const __half*>(srcData), dstData, size);
CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
struct ThrustCopy<half_float::half, DST_T> {
static void copy(const half_float::half* srcData, DST_T* dstData, size_t size) {
cudaCopyFromH_kernel<DST_T><<<(size + 255) / 256, 256>>>
(reinterpret_cast<const __half*>(srcData), dstData, size);
CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
};
// Specialization for SRC_T, half_float::half (where SRC_T is not half_float::half)
template <typename SRC_T>
std::enable_if_t<!std::is_same<half_float::half, SRC_T>::value>
Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size)
{
cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>>
(srcData, reinterpret_cast<__half*>(dstData), size);
CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
struct ThrustCopy<SRC_T, half_float::half> {
static void copy(const SRC_T* srcData, half_float::half* dstData, size_t size) {
cudaCopyToH_kernel<SRC_T><<<(size + 255) / 256, 256>>>
(srcData, reinterpret_cast<__half*>(dstData), size);
CHECK_CUDA_STATUS(cudaPeekAtLastError());
}
};
template <typename SRC_T, typename DST_T>
void Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size) {
ThrustCopy<SRC_T, DST_T>::copy(srcData, dstData, size);
}
/// End Thrust copy
// double
template<> void Aidge::thrust_copy(double const*, double*, size_t);
template<> void Aidge::thrust_copy(double const*, float*, size_t);
template<> void Aidge::thrust_copy(double const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(double const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(double const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(double const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(double const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(double const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(double const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(double const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(double const*, uint8_t*, size_t);
template void Aidge::thrust_copy(double const*, double*, size_t);
template void Aidge::thrust_copy(double const*, float*, size_t);
template void Aidge::thrust_copy(double const*, half_float::half*, size_t);
template void Aidge::thrust_copy(double const*, int64_t*, size_t);
template void Aidge::thrust_copy(double const*, int32_t*, size_t);
template void Aidge::thrust_copy(double const*, int16_t*, size_t);
template void Aidge::thrust_copy(double const*, int8_t*, size_t);
template void Aidge::thrust_copy(double const*, uint64_t*, size_t);
template void Aidge::thrust_copy(double const*, uint32_t*, size_t);
template void Aidge::thrust_copy(double const*, uint16_t*, size_t);
template void Aidge::thrust_copy(double const*, uint8_t*, size_t);
// float
template<> void Aidge::thrust_copy(float const*, double*, size_t);
template<> void Aidge::thrust_copy(float const*, float*, size_t);
template<> void Aidge::thrust_copy(float const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(float const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(float const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(float const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(float const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(float const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(float const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(float const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(float const*, uint8_t*, size_t);
template void Aidge::thrust_copy(float const*, double*, size_t);
template void Aidge::thrust_copy(float const*, float*, size_t);
template void Aidge::thrust_copy(float const*, half_float::half*, size_t);
template void Aidge::thrust_copy(float const*, int64_t*, size_t);
template void Aidge::thrust_copy(float const*, int32_t*, size_t);
template void Aidge::thrust_copy(float const*, int16_t*, size_t);
template void Aidge::thrust_copy(float const*, int8_t*, size_t);
template void Aidge::thrust_copy(float const*, uint64_t*, size_t);
template void Aidge::thrust_copy(float const*, uint32_t*, size_t);
template void Aidge::thrust_copy(float const*, uint16_t*, size_t);
template void Aidge::thrust_copy(float const*, uint8_t*, size_t);
// half_float::half
template<> void Aidge::thrust_copy(half_float::half const*, double*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, float*, size_t);
//template<> void Aidge::thrust_copy(const half_float::half*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(half_float::half const*, uint8_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, double*, size_t);
template void Aidge::thrust_copy(half_float::half const*, float*, size_t);
template void Aidge::thrust_copy(const half_float::half*, half_float::half*, size_t);
template void Aidge::thrust_copy(half_float::half const*, int64_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, int32_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, int16_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, int8_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, uint64_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, uint32_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, uint16_t*, size_t);
template void Aidge::thrust_copy(half_float::half const*, uint8_t*, size_t);
// int64_t
template<> void Aidge::thrust_copy(int64_t const*, double*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, float*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(int64_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, double*, size_t);
template void Aidge::thrust_copy(int64_t const*, float*, size_t);
template void Aidge::thrust_copy(int64_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(int64_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(int64_t const*, uint8_t*, size_t);
// int32_t
template<> void Aidge::thrust_copy(int32_t const*, double*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, float*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(int32_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, double*, size_t);
template void Aidge::thrust_copy(int32_t const*, float*, size_t);
template void Aidge::thrust_copy(int32_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(int32_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(int32_t const*, uint8_t*, size_t);
// int16_t
template<> void Aidge::thrust_copy(int16_t const*, double*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, float*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(int16_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, double*, size_t);
template void Aidge::thrust_copy(int16_t const*, float*, size_t);
template void Aidge::thrust_copy(int16_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(int16_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(int16_t const*, uint8_t*, size_t);
// int8_t
template<> void Aidge::thrust_copy(int8_t const*, double*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, float*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(int8_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, double*, size_t);
template void Aidge::thrust_copy(int8_t const*, float*, size_t);
template void Aidge::thrust_copy(int8_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(int8_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(int8_t const*, uint8_t*, size_t);
// uint64_t
template<> void Aidge::thrust_copy(uint64_t const*, double*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, float*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(uint64_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, double*, size_t);
template void Aidge::thrust_copy(uint64_t const*, float*, size_t);
template void Aidge::thrust_copy(uint64_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(uint64_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(uint64_t const*, uint8_t*, size_t);
// uint32_t
template<> void Aidge::thrust_copy(uint32_t const*, double*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, float*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(uint32_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, double*, size_t);
template void Aidge::thrust_copy(uint32_t const*, float*, size_t);
template void Aidge::thrust_copy(uint32_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(uint32_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(uint32_t const*, uint8_t*, size_t);
// uint16_t
template<> void Aidge::thrust_copy(uint16_t const*, double*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, float*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(uint16_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, double*, size_t);
template void Aidge::thrust_copy(uint16_t const*, float*, size_t);
template void Aidge::thrust_copy(uint16_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(uint16_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(uint16_t const*, uint8_t*, size_t);
// uint8_t
template<> void Aidge::thrust_copy(uint8_t const*, double*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, float*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, half_float::half*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, int64_t*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, int32_t*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, int16_t*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, int8_t*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, uint64_t*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, uint32_t*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, uint16_t*, size_t);
template<> void Aidge::thrust_copy(uint8_t const*, uint8_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, double*, size_t);
template void Aidge::thrust_copy(uint8_t const*, float*, size_t);
template void Aidge::thrust_copy(uint8_t const*, half_float::half*, size_t);
template void Aidge::thrust_copy(uint8_t const*, int64_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, int32_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, int16_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, int8_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, uint64_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, uint32_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, uint16_t*, size_t);
template void Aidge::thrust_copy(uint8_t const*, uint8_t*, size_t);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment