diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp index a66ccdf690603c39ba4a7bf691f0dffea64ddddb..82dd395e6bbb33bae29c5d881290d6996bfb0332 100644 --- a/include/aidge/backend/cuda/utils/CudaContext.hpp +++ b/include/aidge/backend/cuda/utils/CudaContext.hpp @@ -128,6 +128,11 @@ public: } namespace Aidge { + template <> + struct CudaContext::data_type<half_float::half> { + static const cudnnDataType_t value = CUDNN_DATA_HALF; + }; + template <> struct CudaContext::data_type<float> { static const cudnnDataType_t value = CUDNN_DATA_FLOAT; @@ -139,25 +144,25 @@ namespace Aidge { }; inline cudnnDataType_t DataTypeToCudnn(DataType type) { - if (type == DataType::Float32) - return CUDNN_DATA_FLOAT; - - if (type == DataType::Float64) + switch (type) { + case DataType::Float64: return CUDNN_DATA_DOUBLE; - - if (type == DataType::Int8) + case DataType::Float32: + return CUDNN_DATA_FLOAT; + case DataType::Float16: + return CUDNN_DATA_HALF; + case DataType::Int8: return CUDNN_DATA_INT8; - - if (type == DataType::UInt8) + case DataType::UInt8: return CUDNN_DATA_UINT8; - - if (type == DataType::Int32) + case DataType::Int32: return CUDNN_DATA_INT32; - - if (type == DataType::Int64) + case DataType::Int64: return CUDNN_DATA_INT64; - - assert(false && "Unsupported CuDNN type"); + default: + assert(false && "Unsupported CuDNN type"); + } + return CUDNN_DATA_FLOAT; // TODO: undefined behavior } }