Skip to content
Snippets Groups Projects
Commit cc621da3 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'dev' into 'master'

Adding INT64 Tensor support

See merge request !31
parents 832d856b 624670bf
No related branches found
Tags v0.1.0
1 merge request!31Adding INT64 Tensor support
Pipeline #37359 passed
......@@ -16,7 +16,7 @@ class test_tensor(unittest.TestCase):
self.assertTrue("cpu" in aidge_core.Tensor.get_available_backends())
def test_numpy_int_to_tensor(self):
np_array = np.arange(9).reshape(1,1,3,3)
np_array = np.arange(9).reshape(1,1,3,3).astype(np.int32)
# Numpy -> Tensor
t = aidge_core.Tensor(np_array)
self.assertEqual(t.dtype(), aidge_core.DataType.Int32)
......@@ -35,6 +35,16 @@ class test_tensor(unittest.TestCase):
for i,j in zip(t.dims(), nnarray.shape):
self.assertEqual(i,j)
def test_numpy_int64_to_tensor(self):
np_array = np.arange(9).reshape(1,1,3,3).astype(np.int64)
# Numpy -> Tensor
t = aidge_core.Tensor(np_array)
self.assertEqual(t.dtype(), aidge_core.DataType.Int64)
for i_t, i_n in zip(t, np_array.flatten()):
self.assertTrue(i_t == i_n)
for i,j in zip(t.dims(), np_array.shape):
self.assertEqual(i,j)
def test_numpy_float_to_tensor(self):
t = aidge_core.Tensor()
np_array = np.random.rand(1, 1, 3, 3).astype(np.float32)
......@@ -49,7 +59,7 @@ class test_tensor(unittest.TestCase):
def test_get_set(self):
dims = [2,2,2]
np_array = np.arange(8).reshape(dims)
np_array = np.arange(8).reshape(dims).astype(np.int32)
# Numpy -> Tensor
t = aidge_core.Tensor(np_array)
for i in range(8):
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_DATA_TENSORIMPL_H_
#define AIDGE_CPU_DATA_TENSORIMPL_H_
......@@ -10,9 +21,10 @@
#include "aidge/utils/future_std/span.hpp"
namespace Aidge {
template <class T>
class TensorImpl_cpu : public TensorImpl {
private:
private:
const Tensor &mTensor; // Impl needs to access Tensor information, but is not
// supposed to change it!
/// Pointer to the data and its capacity
......@@ -20,19 +32,19 @@ class TensorImpl_cpu : public TensorImpl {
/// If this instance own the data, std::unique_ptr manages it
std::unique_ptr<T[]> mDataOwner;
public:
public:
static constexpr const char *Backend = "cpu";
TensorImpl_cpu(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}
bool operator==(const TensorImpl &otherImpl) const override final {
const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl);
AIDGE_INTERNAL_ASSERT(typedOtherImpl.data().size() >= mTensor.size());
AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mTensor.size());
std::size_t i = 0;
for (; i < mTensor.size() &&
mData[i] == typedOtherImpl.data()[i];
++i) {
*(mData.data()+i) == *static_cast<const T*>(typedOtherImpl.rawPtr(i));
++i) {
}
return i == mTensor.size();
}
......@@ -41,110 +53,110 @@ class TensorImpl_cpu : public TensorImpl {
return std::make_unique<TensorImpl_cpu<T>>(tensor);
}
// native interface
const future_std::span<T>& data() const { return mData; }
std::size_t size() const override { return mData.size(); }
std::size_t scalarSize() const override { return sizeof(T); }
inline std::size_t size() const noexcept override final { return mData.size(); }
inline std::size_t scalarSize() const noexcept override final { return sizeof(T); }
void setDevice(DeviceIdx_t device) override {
void setDevice(DeviceIdx_t device) override final {
AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend");
}
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override {
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override final {
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
static_cast<T *>(rawPtr()) + offset);
}
void copyCast(const void *src, NbElts_t length, const DataType srcDt) override {
void copyCast(const void *src, NbElts_t length, const DataType srcDt) override final {
if (length == 0) {
return;
}
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
if (srcDt == DataType::Float64) {
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::Float32) {
std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::Float16) {
std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::Int64) {
std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::UInt64) {
std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::Int32) {
std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::UInt32) {
std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::Int16) {
std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::UInt16) {
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::Int8) {
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else if (srcDt == DataType::UInt8) {
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
static_cast<T *>(rawPtr()));
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
switch (srcDt)
{
case DataType::Float64:
std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::Float32:
std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::Float16:
std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::Int64:
std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::UInt64:
std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::Int32:
std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::UInt32:
std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::Int16:
std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::UInt16:
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::Int8:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
case DataType::UInt8:
std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
static_cast<T *>(rawPtr()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
break;
}
}
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override {
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override final {
AIDGE_ASSERT(device.first == Backend, "backend must match");
AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend");
copy(src, length);
}
void copyFromHost(const void *src, NbElts_t length) override {
inline void copyFromHost(const void *src, NbElts_t length) override final {
copy(src, length);
}
void copyToHost(void *dst, NbElts_t length) const override {
void copyToHost(void *dst, NbElts_t length) const override final {
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
const T* src = static_cast<const T*>(rawPtr());
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
static_cast<T *>(dst));
}
void *rawPtr(NbElts_t offset = 0) override {
void *rawPtr(NbElts_t offset = 0) override final {
lazyInit();
return (mData.data() + offset);
};
const void *rawPtr(NbElts_t offset = 0) const override {
const void *rawPtr(NbElts_t offset = 0) const override final {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
return (mData.data() + offset);
};
void *hostPtr(NbElts_t offset = 0) override {
void *hostPtr(NbElts_t offset = 0) override final {
lazyInit();
return (mData.data() + offset);
};
const void *hostPtr(NbElts_t offset = 0) const override {
const void *hostPtr(NbElts_t offset = 0) const override final {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const hostPtr");
return (mData.data() + offset);
};
......@@ -177,6 +189,8 @@ static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
{"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
{"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int>::create);
static Registrar<Tensor> registrarTensorImpl_cpu_Int64(
{"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<long>::create);
} // namespace
} // namespace Aidge
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment