Skip to content
Snippets Groups Projects
Commit 1d453e58 authored by Maxence Naud's avatar Maxence Naud
Browse files

Slight optimization of TensorImpl_cpu

parent b5e1d886
No related branches found
No related tags found
1 merge request: !31 "Adding INT64 Tensor support"
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_DATA_TENSORIMPL_H_ #ifndef AIDGE_CPU_DATA_TENSORIMPL_H_
#define AIDGE_CPU_DATA_TENSORIMPL_H_ #define AIDGE_CPU_DATA_TENSORIMPL_H_
...@@ -10,9 +21,10 @@ ...@@ -10,9 +21,10 @@
#include "aidge/utils/future_std/span.hpp" #include "aidge/utils/future_std/span.hpp"
namespace Aidge { namespace Aidge {
template <class T> template <class T>
class TensorImpl_cpu : public TensorImpl { class TensorImpl_cpu : public TensorImpl {
private: private:
const Tensor &mTensor; // Impl needs to access Tensor information, but is not const Tensor &mTensor; // Impl needs to access Tensor information, but is not
// supposed to change it! // supposed to change it!
/// Pointer to the data and its capacity /// Pointer to the data and its capacity
...@@ -20,7 +32,7 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -20,7 +32,7 @@ class TensorImpl_cpu : public TensorImpl {
/// If this instance own the data, std::unique_ptr manages it /// If this instance own the data, std::unique_ptr manages it
std::unique_ptr<T[]> mDataOwner; std::unique_ptr<T[]> mDataOwner;
public: public:
static constexpr const char *Backend = "cpu"; static constexpr const char *Backend = "cpu";
TensorImpl_cpu(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {} TensorImpl_cpu(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}
...@@ -31,8 +43,8 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -31,8 +43,8 @@ class TensorImpl_cpu : public TensorImpl {
std::size_t i = 0; std::size_t i = 0;
for (; i < mTensor.size() && for (; i < mTensor.size() &&
mData[i] == typedOtherImpl.data()[i]; *(mData.data()+i) == *static_cast<const T*>(typedOtherImpl.rawPtr(i));
++i) { ++i) {
} }
return i == mTensor.size(); return i == mTensor.size();
} }
...@@ -41,23 +53,20 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -41,23 +53,20 @@ class TensorImpl_cpu : public TensorImpl {
return std::make_unique<TensorImpl_cpu<T>>(tensor); return std::make_unique<TensorImpl_cpu<T>>(tensor);
} }
// native interface inline std::size_t size() const noexcept override final { return mData.size(); }
auto data() const -> decltype(mData.data()) { return mData.data(); } inline std::size_t scalarSize() const noexcept override final { return sizeof(T); }
std::size_t size() const override { return mData.size(); }
std::size_t scalarSize() const override { return sizeof(T); }
void setDevice(DeviceIdx_t device) override { void setDevice(DeviceIdx_t device) override final {
AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend"); AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend");
} }
void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override { void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override final {
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length, std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
static_cast<T *>(rawPtr()) + offset); static_cast<T *>(rawPtr()) + offset);
} }
void copyCast(const void *src, NbElts_t length, const DataType srcDt) override { void copyCast(const void *src, NbElts_t length, const DataType srcDt) override final {
if (length == 0) { if (length == 0) {
return; return;
} }
...@@ -101,7 +110,7 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -101,7 +110,7 @@ class TensorImpl_cpu : public TensorImpl {
std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
static_cast<T *>(rawPtr())); static_cast<T *>(rawPtr()));
break; break;
case DataType::Int8: case DataType::Int8:
std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
static_cast<T *>(rawPtr())); static_cast<T *>(rawPtr()));
break; break;
...@@ -115,39 +124,39 @@ class TensorImpl_cpu : public TensorImpl { ...@@ -115,39 +124,39 @@ class TensorImpl_cpu : public TensorImpl {
} }
} }
void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override { void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) override final {
AIDGE_ASSERT(device.first == Backend, "backend must match"); AIDGE_ASSERT(device.first == Backend, "backend must match");
AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend"); AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend");
copy(src, length); copy(src, length);
} }
void copyFromHost(const void *src, NbElts_t length) override { inline void copyFromHost(const void *src, NbElts_t length) override final {
copy(src, length); copy(src, length);
} }
void copyToHost(void *dst, NbElts_t length) const override { void copyToHost(void *dst, NbElts_t length) const override final {
AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity"); AIDGE_ASSERT(length <= mData.size() || length <= mTensor.size(), "copy length is above capacity");
const T* src = static_cast<const T*>(rawPtr()); const T* src = static_cast<const T*>(rawPtr());
std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length, std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
static_cast<T *>(dst)); static_cast<T *>(dst));
} }
void *rawPtr(NbElts_t offset = 0) override { void *rawPtr(NbElts_t offset = 0) override final {
lazyInit(); lazyInit();
return (mData.data() + offset); return (mData.data() + offset);
}; };
const void *rawPtr(NbElts_t offset = 0) const override { const void *rawPtr(NbElts_t offset = 0) const override final {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr"); AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const rawPtr");
return (mData.data() + offset); return (mData.data() + offset);
}; };
void *hostPtr(NbElts_t offset = 0) override { void *hostPtr(NbElts_t offset = 0) override final {
lazyInit(); lazyInit();
return (mData.data() + offset); return (mData.data() + offset);
}; };
const void *hostPtr(NbElts_t offset = 0) const override { const void *hostPtr(NbElts_t offset = 0) const override final {
AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const hostPtr"); AIDGE_ASSERT(mData.size() >= mTensor.size(), "accessing uninitialized const hostPtr");
return (mData.data() + offset); return (mData.data() + offset);
}; };
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment