diff --git a/include/aidge/backend/cuda/utils/CudaContext.hpp b/include/aidge/backend/cuda/utils/CudaContext.hpp
index a66ccdf690603c39ba4a7bf691f0dffea64ddddb..82dd395e6bbb33bae29c5d881290d6996bfb0332 100644
--- a/include/aidge/backend/cuda/utils/CudaContext.hpp
+++ b/include/aidge/backend/cuda/utils/CudaContext.hpp
@@ -128,6 +128,11 @@ public:
 }
 
 namespace Aidge {
+    template <>
+    struct CudaContext::data_type<half_float::half> {
+        static const cudnnDataType_t value = CUDNN_DATA_HALF;
+    };
+
     template <>
     struct CudaContext::data_type<float> {
         static const cudnnDataType_t value = CUDNN_DATA_FLOAT;
@@ -139,25 +144,25 @@ namespace Aidge {
     };
 
     inline cudnnDataType_t DataTypeToCudnn(DataType type) {
-        if (type == DataType::Float32)
-            return CUDNN_DATA_FLOAT;
-
-        if (type == DataType::Float64)
+        switch (type) {
+        case DataType::Float64:
             return CUDNN_DATA_DOUBLE;
-
-        if (type == DataType::Int8)
+        case DataType::Float32:
+            return CUDNN_DATA_FLOAT;
+        case DataType::Float16:
+            return CUDNN_DATA_HALF;
+        case DataType::Int8:
             return CUDNN_DATA_INT8;
-
-        if (type == DataType::UInt8)
+        case DataType::UInt8:
             return CUDNN_DATA_UINT8;
-
-        if (type == DataType::Int32)
+        case DataType::Int32:
             return CUDNN_DATA_INT32;
-
-        if (type == DataType::Int64)
+        case DataType::Int64:
             return CUDNN_DATA_INT64;
-        
-        assert(false && "Unsupported CuDNN type");
+        default:
+            assert(false && "Unsupported CuDNN type");
+        }
+
         return CUDNN_DATA_FLOAT;  // TODO: undefined behavior
     }
 }