Skip to content
Snippets Groups Projects
Commit 1f0339d8 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added missing float16 type

parent 4e00cee0
No related branches found
No related tags found
1 merge request!4Add Convert operator (a.k.a. Transmitter)
Pipeline #35590 canceled
...@@ -96,6 +96,11 @@ public: ...@@ -96,6 +96,11 @@ public:
static_cast<T*>(rawPtr()), static_cast<T*>(rawPtr()),
length); length);
} }
else if (srcDt == DataType::Float16) {
thrust_copy(static_cast<const half_float::half*>(src),
static_cast<T*>(rawPtr()),
length);
}
else if (srcDt == DataType::Int64) { else if (srcDt == DataType::Int64) {
thrust_copy(static_cast<const int64_t*>(src), thrust_copy(static_cast<const int64_t*>(src),
static_cast<T*>(rawPtr()), static_cast<T*>(rawPtr()),
...@@ -230,6 +235,8 @@ static Registrar<Tensor> registrarTensorImpl_cuda_Float64( ...@@ -230,6 +235,8 @@ static Registrar<Tensor> registrarTensorImpl_cuda_Float64(
{"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create); {"cuda", DataType::Float64}, Aidge::TensorImpl_cuda<double>::create);
static Registrar<Tensor> registrarTensorImpl_cuda_Float32( static Registrar<Tensor> registrarTensorImpl_cuda_Float32(
{"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create); {"cuda", DataType::Float32}, Aidge::TensorImpl_cuda<float>::create);
static Registrar<Tensor> registrarTensorImpl_cuda_Float32(
{"cuda", DataType::Float16}, Aidge::TensorImpl_cuda<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cuda_Int32( static Registrar<Tensor> registrarTensorImpl_cuda_Int32(
{"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int>::create); {"cuda", DataType::Int32}, Aidge::TensorImpl_cuda<int>::create);
} // namespace } // namespace
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment