[Upd] tensorImpl 'data()' and switch cases
parent
62519657
No related branches found
No related tags found
This commit is part of merge request !31. Comments created here will be created in the context of that merge request.
... | @@ -27,7 +27,7 @@ class TensorImpl_cpu : public TensorImpl { | ... | @@ -27,7 +27,7 @@ class TensorImpl_cpu : public TensorImpl { |
bool operator==(const TensorImpl &otherImpl) const override final { | bool operator==(const TensorImpl &otherImpl) const override final { | ||
const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl); | const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl); | ||
AIDGE_INTERNAL_ASSERT(typedOtherImpl.data().size() >= mTensor.size()); | AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mTensor.size()); | ||
std::size_t i = 0; | std::size_t i = 0; | ||
for (; i < mTensor.size() && | for (; i < mTensor.size() && | ||
... | @@ -42,7 +42,7 @@ class TensorImpl_cpu : public TensorImpl { | ... | @@ -42,7 +42,7 @@ class TensorImpl_cpu : public TensorImpl { |
} | } | ||
// native interface | // native interface | ||
const future_std::span<T>& data() const { return mData; } | auto data() const -> decltype(mData.data()) { return mData.data(); } | ||
std::size_t size() const override { return mData.size(); } | std::size_t size() const override { return mData.size(); } | ||
std::size_t scalarSize() const override { return sizeof(T); } | std::size_t scalarSize() const override { return sizeof(T); } | ||
... | @@ -63,52 +63,55 @@ class TensorImpl_cpu : public TensorImpl { | ... | @@ -63,52 +63,55 @@ class TensorImpl_cpu : public TensorImpl { |
} | } | ||
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); | AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); | ||
if (srcDt == DataType::Float64) { | switch (srcDt) | ||
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length, | { | ||
static_cast<T *>(rawPtr())); | case DataType::Float64: | ||
} | std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length, | ||
else if (srcDt == DataType::Float32) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::Float32: | ||
} | std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, | ||
else if (srcDt == DataType::Float16) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::Float16: | ||
} | std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length, | ||
else if (srcDt == DataType::Int64) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::Int64: | ||
} | std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, | ||
else if (srcDt == DataType::UInt64) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::UInt64: | ||
} | std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length, | ||
else if (srcDt == DataType::Int32) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::Int32: | ||
} | std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length, | ||
else if (srcDt == DataType::UInt32) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::UInt32: | ||
} | std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length, | ||
else if (srcDt == DataType::Int16) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::Int16: | ||
} | std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length, | ||
else if (srcDt == DataType::UInt16) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::UInt16: | ||
} | std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, | ||
else if (srcDt == DataType::Int8) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case ataType::Int8: | ||
|
|||
} | std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, | ||
else if (srcDt == DataType::UInt8) { | static_cast<T *>(rawPtr())); | ||
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, | break; | ||
static_cast<T *>(rawPtr())); | case DataType::UInt8: | ||
} | std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, | ||
else { | static_cast<T *>(rawPtr())); | ||
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); | break; | ||
default: | |||
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); | |||
break; | |||
} | } | ||
} | } | ||
... | ... |