Skip to content
Snippets Groups Projects
Commit f80c3d59 authored by Charles Villard's avatar Charles Villard Committed by Maxence Naud
Browse files

edit: TensorImpl.{hpp,cu}: Fixed template in the general case

parent 6c3aa17b
No related branches found
No related tags found
2 merge requests!75Update 0.5.1 -> 0.6.0,!40Fix template compilation and warnings
......@@ -22,17 +22,18 @@
namespace Aidge {
template <typename SRC_T, typename DST_T>
void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
typename std::enable_if<!std::is_same<half_float::half, DST_T>::value && !std::is_same<half_float::half, SRC_T>::value>::type
thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
template <typename SRC_T>
template <typename SRC_T, typename DST_T = half_float::half>
typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type
thrust_copy(const SRC_T *srcData, half_float::half *dstData, size_t size);
template <typename DST_T>
template <typename SRC_T = half_float::half, typename DST_T>
typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type
thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
template <>
template <typename SRC_T = half_float::half, typename DST_T = half_float::half>
void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size);
/**
......@@ -121,57 +122,57 @@ public:
case DataType::Float64:
thrust_copy(static_cast<const double*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Float32:
thrust_copy(static_cast<const float*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Float16:
thrust_copy(static_cast<const half_float::half*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int64:
thrust_copy(static_cast<const int64_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt64:
thrust_copy(static_cast<const uint64_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int32:
thrust_copy(static_cast<const int32_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt32:
thrust_copy(static_cast<const uint32_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int16:
thrust_copy(static_cast<const int16_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt16:
thrust_copy(static_cast<const uint16_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int8:
thrust_copy(static_cast<const int8_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt8:
thrust_copy(static_cast<const uint8_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "TensorImpl_cuda<{}>::copyCast(): unsupported data type {}.", typeid(T).name(), srcDt);
......
......@@ -15,7 +15,8 @@
#include <thrust/device_ptr.h>
template <typename SRC_T, typename DST_T>
void Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size)
typename std::enable_if<!std::is_same<half_float::half, DST_T>::value && !std::is_same<half_float::half, SRC_T>::value>::type
Aidge::thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size)
{
const thrust::device_ptr<const SRC_T> thrustSrcPtr(srcData);
thrust::device_ptr<DST_T> thrustDstPtr(dstData);
......@@ -36,7 +37,7 @@ cudaCopyToH_kernel(const SRC_T* srcData,
}
}
template <typename SRC_T>
template <typename SRC_T, typename DST_T>
typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type
Aidge::thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size)
{
......@@ -59,7 +60,7 @@ cudaCopyFromH_kernel(const __half* srcData,
}
}
template <typename DST_T>
template <typename SRC_T, typename DST_T>
typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type
Aidge::thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size)
{
......@@ -81,7 +82,7 @@ cudaCopyHToH_kernel(const __half* srcData,
}
}
template <>
template <typename SRC_T, typename DST_T>
void Aidge::thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size)
{
cudaCopyHToH_kernel<<<(size + 255) / 256, 256>>>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment