Add Float16 support to the CUDA tensor implementation.

Hunk 1 adds a `srcDt == DataType::Float16` branch to the source-type dispatch
chain; it forwards the source buffer, viewed as `half_float::half`, to
thrust_copy() exactly like the neighbouring branches do.
Hunk 2 registers `TensorImpl_cuda<half_float::half>::create` under the
{"cuda", DataType::Float16} key. The registrar variable is named
`registrarTensorImpl_cuda_Float16`: reusing `registrarTensorImpl_cuda_Float32`
would redefine the variable declared just above it and fail to compile.

NOTE(review): leading whitespace of the hunk lines was reconstructed; confirm
it matches the target file (or apply with whitespace-insensitive matching).

diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp
index 45f8a6c3bd8ac9e40f8f7716119e7a36ecfc1b3c..7d7b33f77237f40eb12d8aee794aaecf32804e15 100644
--- a/include/aidge/backend/cuda/data/TensorImpl.hpp
+++ b/include/aidge/backend/cuda/data/TensorImpl.hpp
@@ -96,6 +96,11 @@ public:
                         static_cast<T*>(rawPtr()),
                         length);
         }
+        else if (srcDt == DataType::Float16) {
+            thrust_copy(static_cast<const half_float::half*>(src),
+                        static_cast<T*>(rawPtr()),
+                        length);
+        }
         else if (srcDt == DataType::Int64) {
             thrust_copy(static_cast<const int64_t*>(src),
                         static_cast<T*>(rawPtr()),
@@ -230,6 +235,8 @@ static Registrar<Tensor> registrarTensorImpl_cuda_Float64(
         {"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create);
 static Registrar<Tensor> registrarTensorImpl_cuda_Float32(
         {"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create);
+static Registrar<Tensor> registrarTensorImpl_cuda_Float16(
+        {"cuda", DataType::Float16}, Aidge::TensorImpl_cuda<half_float::half>::create);
 static Registrar<Tensor> registrarTensorImpl_cuda_Int32(
         {"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int>::create);
 } // namespace