Skip to content
Snippets Groups Projects
Commit 69cc72c5 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] tensorImpl 'data()' and switch cases

parent 62519657
No related branches found
No related tags found
1 merge request!31Adding INT64 Tensor support
...@@ -27,7 +27,7 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -27,7 +27,7 @@ class TensorImpl_cpu : public TensorImpl {
bool operator==(const TensorImpl &otherImpl) const override final { bool operator==(const TensorImpl &otherImpl) const override final {
const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl); const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl);
AIDGE_INTERNAL_ASSERT(typedOtherImpl.data().size() >= mTensor.size()); AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mTensor.size());
std::size_t i = 0; std::size_t i = 0;
for (; i < mTensor.size() && for (; i < mTensor.size() &&
...@@ -42,7 +42,7 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -42,7 +42,7 @@ class TensorImpl_cpu : public TensorImpl {
} }
// native interface // native interface
const future_std::span<T>& data() const { return mData; } auto data() const -> decltype(mData.data()) { return mData.data(); }
std::size_t size() const override { return mData.size(); } std::size_t size() const override { return mData.size(); }
std::size_t scalarSize() const override { return sizeof(T); } std::size_t scalarSize() const override { return sizeof(T); }
...@@ -63,52 +63,55 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -63,52 +63,55 @@ class TensorImpl_cpu : public TensorImpl {
} }
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
if (srcDt == DataType::Float64) { switch (srcDt)
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length, {
static_cast<T *>(rawPtr())); case DataType::Float64:
} std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
else if (srcDt == DataType::Float32) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::Float32:
} std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
else if (srcDt == DataType::Float16) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::Float16:
} std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
else if (srcDt == DataType::Int64) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::Int64:
} std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
else if (srcDt == DataType::UInt64) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::UInt64:
} std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
else if (srcDt == DataType::Int32) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::Int32:
} std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
else if (srcDt == DataType::UInt32) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::UInt32:
} std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
else if (srcDt == DataType::Int16) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::Int16:
} std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
else if (srcDt == DataType::UInt16) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::UInt16:
} std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
else if (srcDt == DataType::Int8) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case ataType::Int8:
} std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
else if (srcDt == DataType::UInt8) { static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, break;
static_cast<T *>(rawPtr())); case DataType::UInt8:
} std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
else { static_cast<T *>(rawPtr()));
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
break;
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment