From 1f0339d8f2044b5c4ed121f1db50fc725e6d0087 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 10 Dec 2023 19:07:34 +0100 Subject: [PATCH] Added missing float16 type --- include/aidge/backend/cuda/data/TensorImpl.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index 45f8a6c..7d7b33f 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -96,6 +96,11 @@ public: static_cast<T*>(rawPtr()), length); } + else if (srcDt == DataType::Float16) { + thrust_copy(static_cast<const half_float::half*>(src), + static_cast<T*>(rawPtr()), + length); + } else if (srcDt == DataType::Int64) { thrust_copy(static_cast<const int64_t*>(src), static_cast<T*>(rawPtr()), @@ -230,6 +235,8 @@ static Registrar<Tensor> registrarTensorImpl_cuda_Float64( {"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create); static Registrar<Tensor> registrarTensorImpl_cuda_Float32( {"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create); +static Registrar<Tensor> registrarTensorImpl_cuda_Float32( + {"cuda", DataType::Float16}, Aidge::TensorImpl_cuda<half_float::half>::create); static Registrar<Tensor> registrarTensorImpl_cuda_Int32( {"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int>::create); } // namespace -- GitLab