Skip to content
Snippets Groups Projects

Fix template compilation and warnings

Merged Charles Villard requested to merge silvanosky/aidge_backend_cuda:compilation_fix1 into dev
Compare and
3 files
+ 205
181
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -7,6 +7,7 @@
#include <vector>
#include <cuda.h>
#include <type_traits> // std::enable_if, std::is_same
#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Tensor.hpp"
@@ -20,14 +21,9 @@
namespace Aidge {
template <typename SRC_T, typename DST_T>
void thrust_copy(const SRC_T* /*srcData*/, DST_T* /*dstData*/, size_t /*size*/);
template <typename SRC_T, typename std::enable_if<!std::is_same<half_float::half, SRC_T>::value>::type* = nullptr>
void thrust_copy(const SRC_T* srcData, half_float::half* dstData, size_t size);
template <typename DST_T, typename std::enable_if<!std::is_same<half_float::half, DST_T>::value>::type* = nullptr>
void thrust_copy(const half_float::half* srcData, DST_T* dstData, size_t size);
template <>
void thrust_copy(const half_float::half* srcData, half_float::half* dstData, size_t size);
void thrust_copy(const SRC_T* srcData, DST_T* dstData, size_t size);
/**
* @brief Abstract class for the TensorImpl_cuda class template.
@@ -115,57 +111,57 @@ public:
case DataType::Float64:
thrust_copy(static_cast<const double*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Float32:
thrust_copy(static_cast<const float*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Float16:
thrust_copy(static_cast<const half_float::half*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int64:
thrust_copy(static_cast<const int64_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt64:
thrust_copy(static_cast<const uint64_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int32:
thrust_copy(static_cast<const int32_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt32:
thrust_copy(static_cast<const uint32_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int16:
thrust_copy(static_cast<const int16_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt16:
thrust_copy(static_cast<const uint16_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::Int8:
thrust_copy(static_cast<const int8_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
case DataType::UInt8:
thrust_copy(static_cast<const uint8_t*>(src),
static_cast<T*>(rawPtr(offset)),
length);
static_cast<size_t>(length));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "TensorImpl_cuda<{}>::copyCast(): unsupported data type {}.", typeid(T).name(), srcDt);
Loading