From 1ccc5bae5fb6061ba9634a1a813922925a7cb3e0 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 10 Dec 2023 22:18:26 +0100 Subject: [PATCH] Added half support in CudaContext.hpp --- .../aidge/backend/cuda/utils/CudaContext.hpp | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp index a66ccdf..82dd395 100644 --- a/include/aidge/backend/cuda/utils/CudaContext.hpp +++ b/include/aidge/backend/cuda/utils/CudaContext.hpp @@ -128,6 +128,11 @@ public: } namespace Aidge { + template <> + struct CudaContext::data_type<half_float::half> { + static const cudnnDataType_t value = CUDNN_DATA_HALF; + }; + template <> struct CudaContext::data_type<float> { static const cudnnDataType_t value = CUDNN_DATA_FLOAT; @@ -139,25 +144,25 @@ namespace Aidge { }; inline cudnnDataType_t DataTypeToCudnn(DataType type) { - if (type == DataType::Float32) - return CUDNN_DATA_FLOAT; - - if (type == DataType::Float64) + switch (type) { + case DataType::Float64: return CUDNN_DATA_DOUBLE; - - if (type == DataType::Int8) + case DataType::Float32: + return CUDNN_DATA_FLOAT; + case DataType::Float16: + return CUDNN_DATA_HALF; + case DataType::Int8: return CUDNN_DATA_INT8; - - if (type == DataType::UInt8) + case DataType::UInt8: return CUDNN_DATA_UINT8; - - if (type == DataType::Int32) + case DataType::Int32: return CUDNN_DATA_INT32; - - if (type == DataType::Int64) + case DataType::Int64: return CUDNN_DATA_INT64; - - assert(false && "Unsupported CuDNN type"); + default: + assert(false && "Unsupported CuDNN type"); + } + return CUDNN_DATA_FLOAT; // TODO: undefined behavior } } -- GitLab