Skip to content
Snippets Groups Projects
Commit 1ccc5bae authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added half support in CudaContext.hpp

parent 1c8e1637
No related branches found
No related tags found
1 merge request!4Add Convert operator (a.k.a. Transmitter)
Pipeline #35595 failed
...@@ -128,6 +128,11 @@ public: ...@@ -128,6 +128,11 @@ public:
} }
namespace Aidge { namespace Aidge {
template <>
struct CudaContext::data_type<half_float::half> {
static const cudnnDataType_t value = CUDNN_DATA_HALF;
};
template <> template <>
struct CudaContext::data_type<float> { struct CudaContext::data_type<float> {
static const cudnnDataType_t value = CUDNN_DATA_FLOAT; static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
...@@ -139,25 +144,25 @@ namespace Aidge { ...@@ -139,25 +144,25 @@ namespace Aidge {
}; };
inline cudnnDataType_t DataTypeToCudnn(DataType type) { inline cudnnDataType_t DataTypeToCudnn(DataType type) {
if (type == DataType::Float32) switch (type) {
return CUDNN_DATA_FLOAT; case DataType::Float64:
if (type == DataType::Float64)
return CUDNN_DATA_DOUBLE; return CUDNN_DATA_DOUBLE;
case DataType::Float32:
if (type == DataType::Int8) return CUDNN_DATA_FLOAT;
case DataType::Float16:
return CUDNN_DATA_HALF;
case DataType::Int8:
return CUDNN_DATA_INT8; return CUDNN_DATA_INT8;
case DataType::UInt8:
if (type == DataType::UInt8)
return CUDNN_DATA_UINT8; return CUDNN_DATA_UINT8;
case DataType::Int32:
if (type == DataType::Int32)
return CUDNN_DATA_INT32; return CUDNN_DATA_INT32;
case DataType::Int64:
if (type == DataType::Int64)
return CUDNN_DATA_INT64; return CUDNN_DATA_INT64;
default:
assert(false && "Unsupported CuDNN type"); assert(false && "Unsupported CuDNN type");
}
return CUDNN_DATA_FLOAT; // TODO: undefined behavior return CUDNN_DATA_FLOAT; // TODO: undefined behavior
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment