From 69cc72c54a0aab17b694a3d7b7a013a333c6ff82 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 12 Jan 2024 14:26:50 +0000 Subject: [PATCH] [Upd] tensorImpl 'data()' and switch cases --- include/aidge/backend/cpu/data/TensorImpl.hpp | 99 ++++++++++--------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index cce8c3f6..b02c9ef2 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -27,7 +27,7 @@ class TensorImpl_cpu : public TensorImpl { bool operator==(const TensorImpl &otherImpl) const override final { const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl); - AIDGE_INTERNAL_ASSERT(typedOtherImpl.data().size() >= mTensor.size()); + AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mTensor.size()); std::size_t i = 0; for (; i < mTensor.size() && @@ -42,7 +42,7 @@ class TensorImpl_cpu : public TensorImpl { } // native interface - const future_std::span<T>& data() const { return mData; } + auto data() const -> decltype(mData.data()) { return mData.data(); } std::size_t size() const override { return mData.size(); } std::size_t scalarSize() const override { return sizeof(T); } @@ -63,52 +63,55 @@ class TensorImpl_cpu : public TensorImpl { } AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); - if (srcDt == DataType::Float64) { - std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::Float32) { - std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::Float16) { - std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::Int64) { - std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::UInt64) { - std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::Int32) { - std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::UInt32) { - std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::Int16) { - std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::UInt16) { - std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::Int8) { - std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else if (srcDt == DataType::UInt8) { - std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, - static_cast<T *>(rawPtr())); - } - else { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); + switch (srcDt) + { + case DataType::Float64: + std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::Float32: + std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::Float16: + std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::Int64: + std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::UInt64: + std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::Int32: + std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::UInt32: + std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::Int16: + std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::UInt16: + std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case ataType::Int8: + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + case DataType::UInt8: + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + static_cast<T *>(rawPtr())); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); + break; } } -- GitLab