diff --git a/include/aidge/backend/cuda/data/TensorImpl.hpp b/include/aidge/backend/cuda/data/TensorImpl.hpp index e2d53ffe38bc736126cb551e0a52ae31aa392521..f7de7aaf434ddf8a19f225f1c5db49780fde88d2 100644 --- a/include/aidge/backend/cuda/data/TensorImpl.hpp +++ b/include/aidge/backend/cuda/data/TensorImpl.hpp @@ -97,63 +97,65 @@ public: } AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity"); - if (srcDt == DataType::Float64) { + switch (srcDt) { + case DataType::Float64: thrust_copy(static_cast<const double*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::Float32) { + break; + case DataType::Float32: thrust_copy(static_cast<const float*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::Float16) { + break; + case DataType::Float16: thrust_copy(static_cast<const half_float::half*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::Int64) { + break; + case DataType::Int64: thrust_copy(static_cast<const int64_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::UInt64) { + break; + case DataType::UInt64: thrust_copy(static_cast<const uint64_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::Int32) { + break; + case DataType::Int32: thrust_copy(static_cast<const int32_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::UInt32) { + break; + case DataType::UInt32: thrust_copy(static_cast<const uint32_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::Int16) { + break; + case DataType::Int16: thrust_copy(static_cast<const int16_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::UInt16) { + break; + case DataType::UInt16: thrust_copy(static_cast<const uint16_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::Int8) { + break; + case DataType::Int8: thrust_copy(static_cast<const int8_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else if (srcDt == DataType::UInt8) { + break; + case DataType::UInt8: thrust_copy(static_cast<const uint8_t*>(src), static_cast<T*>(rawPtr(offset)), length); - } - else { + break; + default: AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); + break; } }