Skip to content
Snippets Groups Projects

Adding INT64 Tensor support

Merged Cyril Moineau requested to merge dev into master
1 unresolved thread
 
/********************************************************************************
 
* Copyright (c) 2023 CEA-List
 
*
 
* This program and the accompanying materials are made available under the
 
* terms of the Eclipse Public License 2.0 which is available at
 
* http://www.eclipse.org/legal/epl-2.0.
 
*
 
* SPDX-License-Identifier: EPL-2.0
 
*
 
********************************************************************************/
 
#ifndef AIDGE_CPU_DATA_TENSORIMPL_H_
#ifndef AIDGE_CPU_DATA_TENSORIMPL_H_
#define AIDGE_CPU_DATA_TENSORIMPL_H_
#define AIDGE_CPU_DATA_TENSORIMPL_H_
@@ -10,9 +21,10 @@
@@ -10,9 +21,10 @@
#include "aidge/utils/future_std/span.hpp"
#include "aidge/utils/future_std/span.hpp"
namespace Aidge {
namespace Aidge {
 
template <class T>
template <class T>
class TensorImpl_cpu : public TensorImpl {
class TensorImpl_cpu : public TensorImpl {
private:
private:
const Tensor &mTensor; // Impl needs to access Tensor information, but is not
const Tensor &mTensor; // Impl needs to access Tensor information, but is not
// supposed to change it!
// supposed to change it!
/// Pointer to the data and its capacity
/// Pointer to the data and its capacity
@@ -20,19 +32,19 @@ class TensorImpl_cpu : public TensorImpl {
@@ -20,19 +32,19 @@ class TensorImpl_cpu : public TensorImpl {
/// If this instance own the data, std::unique_ptr manages it
/// If this instance own the data, std::unique_ptr manages it
std::unique_ptr<T[]> mDataOwner;
std::unique_ptr<T[]> mDataOwner;
public:
public:
static constexpr const char *Backend = "cpu";
static constexpr const char *Backend = "cpu";
TensorImpl_cpu(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}
TensorImpl_cpu(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}
bool operator==(const TensorImpl &otherImpl) const override final {
bool operator==(const TensorImpl &otherImpl) const override final {
const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl);
const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl);
AIDGE_INTERNAL_ASSERT(typedOtherImpl.data().size() >= mTensor.size());
AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mTensor.size());
std::size_t i = 0;
std::size_t i = 0;
for (; i < mTensor.size() &&
for (; i < mTensor.size() &&
mData[i] == typedOtherImpl.data()[i];
*(mData.data()+i) == *static_cast<const T*>(typedOtherImpl.rawPtr(i));
++i) {
++i) {
}
}
return i == mTensor.size();
return i == mTensor.size();
}
}
@@ -41,110 +53,110 @@ class TensorImpl_cpu : public TensorImpl {
@@ -41,110 +53,110 @@ class TensorImpl_cpu : public TensorImpl {
return std::make_unique<TensorImpl_cpu<T>>(tensor);
return std::make_unique<TensorImpl_cpu<T>>(tensor);
}
}
// native interface
inline std::size_t size() const noexcept override final { return mData.size(); }
const future_std::span<T>& data() const { return mData; }
inline std::size_t scalarSize() const noexcept override final { return sizeof(T); }
std::size_t size() const override { return mData.size(); }
std::size_t scalarSize() const override { return sizeof(T); }
void setDevice(DeviceIdx_t device) override {
void setDevice(DeviceIdx_t device) override final {
AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend");
AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend");
}
}
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override final {
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
static_cast<T *>(rawPtr()) + offset);
static_cast<T *>(rawPtr()) + offset);
}
}
void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
void copyCast(const void *src, NbElts_t length, const DataType srcDt) override final {
if (length == 0) {
if (length == 0) {
return;
return;
}
}
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
if (srcDt == DataType::Float64) {
switch (srcDt)
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
{
static_cast<T *>(rawPtr()));
case DataType::Float64:
}
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
else if (srcDt == DataType::Float32) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::Float32:
}
std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
else if (srcDt == DataType::Float16) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::Float16:
}
std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
else if (srcDt == DataType::Int64) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::Int64:
}
std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
else if (srcDt == DataType::UInt64) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::UInt64:
}
std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
else if (srcDt == DataType::Int32) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::Int32:
}
std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
else if (srcDt == DataType::UInt32) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::UInt32:
}
std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
else if (srcDt == DataType::Int16) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::Int16:
}
std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
else if (srcDt == DataType::UInt16) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::UInt16:
}
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
else if (srcDt == DataType::Int8) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::Int8:
}
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
else if (srcDt == DataType::UInt8) {
static_cast<T *>(rawPtr()));
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
break;
static_cast<T *>(rawPtr()));
case DataType::UInt8:
}
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
else {
static_cast<T *>(rawPtr()));
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
break;
 
default:
 
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
 
break;
}
}
}
}
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override {
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override final {
AIDGE_ASSERT(device.first == Backend, "backend must match");
AIDGE_ASSERT(device.first == Backend, "backend must match");
AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend");
AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend");
copy(src, length);
copy(src, length);
}
}
void copyFromHost(const void *src, NbElts_t length) override {
inline void copyFromHost(const void *src, NbElts_t length) override final {
copy(src, length);
copy(src, length);
}
}
void copyToHost(void *dst, NbElts_t length) const override {
void copyToHost(void *dst, NbElts_t length) const override final {
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
const T* src = static_cast<const T*>(rawPtr());
const T* src = static_cast<const T*>(rawPtr());
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
static_cast<T *>(dst));
static_cast<T *>(dst));
}
}
void *rawPtr(NbElts_t offset = 0) override {
void *rawPtr(NbElts_t offset = 0) override final {
lazyInit();
lazyInit();
return (mData.data() + offset);
return (mData.data() + offset);
};
};
const void *rawPtr(NbElts_t offset = 0) const override {
const void *rawPtr(NbElts_t offset = 0) const override final {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
return (mData.data() + offset);
return (mData.data() + offset);
};
};
void *hostPtr(NbElts_t offset = 0) override {
void *hostPtr(NbElts_t offset = 0) override final {
lazyInit();
lazyInit();
return (mData.data() + offset);
return (mData.data() + offset);
};
};
const void *hostPtr(NbElts_t offset = 0) const override {
const void *hostPtr(NbElts_t offset = 0) const override final {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const hostPtr");
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const hostPtr");
return (mData.data() + offset);
return (mData.data() + offset);
};
};
@@ -177,6 +189,8 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
@@ -177,6 +189,8 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
{"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
{"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
{"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create);
{"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create);
 
static Registrar<Tensor> registrarTensorImpl_cpu_Int64(
 
{"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<long>::create);
} // namespace
} // namespace
} // namespace Aidge
} // namespace Aidge
Loading