Commit b2170f4a authored by Maxence Naud

Change 'TensorImpl' mBackend member variable type and move member functions to source file

- type changed from 'const char*' to 'std::string'
parent b50429c2
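For orientation, here is a minimal, self-contained sketch (not the Aidge API itself) of what the type change amounts to: the backend identifier is stored as a std::string instead of a raw const char*, so device() no longer needs to wrap it in a std::string and backend() now returns an owning string (and can no longer be constexpr). DeviceIdx_t is assumed here to be a small unsigned index, similar to the alias in aidge/utils/Types.h.

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

// Minimal sketch, assuming DeviceIdx_t is an 8-bit device index.
using DeviceIdx_t = std::uint8_t;

class TensorImplSketch {
    const std::string mBackend;   // was: const char *mBackend;
    const DeviceIdx_t mDevice;
public:
    TensorImplSketch(const std::string& backend, DeviceIdx_t device)
        : mBackend(backend), mDevice(device) {}

    // No std::string(mBackend) wrapping needed any more.
    std::pair<std::string, DeviceIdx_t> device() const noexcept {
        return std::make_pair(mBackend, mDevice);
    }

    // Mirrors the new accessor: returns a string by value instead of a raw pointer.
    const std::string backend() const { return mBackend; }
};

int main() {
    TensorImplSketch impl("cpu", 0);
    std::cout << impl.backend() << ", device " << int(impl.device().second) << '\n';
    return 0;
}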
@@ -72,7 +72,7 @@ private:
 class TensorImpl {
 protected:
-    const char *mBackend;
+    const std::string mBackend;
     /// @brief Device id.
     const DeviceIdx_t mDevice;
     /// Number of elements (to be) stored.
@@ -81,7 +81,7 @@ protected:
 public:
     TensorImpl() = delete;
-    TensorImpl(const char *backend, DeviceIdx_t device, std::vector<DimSize_t> dims)
+    TensorImpl(const std::string& backend, DeviceIdx_t device, std::vector<DimSize_t> dims)
         : mBackend(backend),
           mDevice(device)
     {
@@ -97,7 +97,7 @@ public:
      * Return the (backend, device) pair for this implementation.
      */
     std::pair<std::string, DeviceIdx_t> device() const noexcept {
-        return std::make_pair(std::string(mBackend), mDevice);
+        return std::make_pair(mBackend, mDevice);
     }
     /**
@@ -194,7 +194,7 @@ public:
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Function not implented");
     }
-    constexpr const char *backend() const { return mBackend; }
+    const std::string backend() const { return mBackend; }
     /**
      * @brief Copy from another backend.
......
@@ -14,7 +14,6 @@
#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/half.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
@@ -31,21 +30,12 @@ private:
     std::unique_ptr<T[]> mDataOwner;
 public:
-    static constexpr const char *Backend = "cpu";
+    static const std::string Backend;
 public:
     TensorImpl_cpu(DeviceIdx_t device, std::vector<DimSize_t> dims) : TensorImpl(Backend, device, dims) {}
-    bool operator==(const TensorImpl &otherImpl) const override final {
-        const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl);
-        AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mNbElts);
-        std::size_t i = 0;
-        for (; i < mNbElts &&
-            *static_cast<const T*>(rawPtr(i)) == *static_cast<const T*>(typedOtherImpl.rawPtr(i));
-            ++i) {
-        }
-        return i == mNbElts;
-    }
+    bool operator==(const TensorImpl &other) const override final;
     static std::shared_ptr<TensorImpl_cpu> create(DeviceIdx_t device, std::vector<DimSize_t> dims) {
         return std::make_shared<TensorImpl_cpu<T>>(device, dims);
@@ -53,14 +43,7 @@ public:
     inline std::size_t scalarSize() const noexcept override final { return sizeof(T); }
-    void zeros() override final {
-        if (mData.empty()) {
-            lazyInit();
-        }
-        for (std::size_t i = 0; i < mData.size(); ++i) {
-            *(mData.data() + i) = T(0);
-        }
-    }
+    void zeros() override final;
     void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override final {
         const T* srcT = static_cast<const T *>(src);
@@ -71,64 +54,7 @@ public:
         std::copy(srcT, srcT + length, dstT);
     }
-    void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override final {
-        if (length == 0) {
-            return;
-        }
-        T* dstT = static_cast<T *>(rawPtr(offset));
-        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
-        switch (srcDt)
-        {
-            case DataType::Float64:
-                std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Float32:
-                std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Float16:
-                std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int64:
-                std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt64:
-                std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int32:
-                std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt32:
-                std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int16:
-                std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt16:
-                std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int8:
-                std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt8:
-                std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
-                        dstT);
-                break;
-            default:
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
-                break;
-        }
-    }
+    void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override final;
     void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) override final {
         AIDGE_ASSERT(device.first == Backend, "backend must match");
@@ -185,6 +111,10 @@ private:
     }
 };
+template <typename T>
+const std::string TensorImpl_cpu<T>::Backend = "cpu";
 namespace {
 static Registrar<Tensor> registrarTensorImpl_cpu_Float64(
         {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/data/TensorImpl.hpp"
#include <algorithm> // std::copy
#include <cstddef> // std::size_t
#include <cstdint> // std::uint8_t, std::int8_t, std::uint16_t, std::int16_t,
// std::uint32_t, std::int32_t, std::uint64_t, std::int64_t
#include <string>
#include "aidge/data/half.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
template <typename T>
bool Aidge::TensorImpl_cpu<T>::operator==(const Aidge::TensorImpl &other) const {
    const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T>&>(other);
    AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mNbElts);

    std::size_t i = 0;
    for (;
        i < mNbElts &&
        *static_cast<const T*>(rawPtr(i)) == *static_cast<const T*>(typedOtherImpl.rawPtr(i));
        ++i)
    {}
    return i == mNbElts;
}

template <typename T>
void Aidge::TensorImpl_cpu<T>::zeros() {
    if (mData.empty()) {
        lazyInit();
    }
    for (std::size_t i = 0; i < mData.size(); ++i) {
        *(mData.data() + i) = T(0);
    }
}

template <typename T>
void Aidge::TensorImpl_cpu<T>::copyCast(const void *src, const Aidge::DataType srcDt, Aidge::NbElts_t length, Aidge::NbElts_t offset) {
    if (length == 0) {
        return;
    }

    T* dstT = static_cast<T *>(rawPtr(offset));
    AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
    switch (srcDt)
    {
        case DataType::Float64:
            std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
                    dstT);
            break;
        case DataType::Float32:
            std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
                    dstT);
            break;
        case DataType::Float16:
            std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
                    dstT);
            break;
        case DataType::Int64:
            std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
                    dstT);
            break;
        case DataType::UInt64:
            std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
                    dstT);
            break;
        case DataType::Int32:
            std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
                    dstT);
            break;
        case DataType::UInt32:
            std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
                    dstT);
            break;
        case DataType::Int16:
            std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
                    dstT);
            break;
        case DataType::UInt16:
            std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
                    dstT);
            break;
        case DataType::Int8:
            std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
                    dstT);
            break;
        case DataType::UInt8:
            std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
                    dstT);
            break;
        default:
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
            break;
    }
}
\ No newline at end of file
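Since TensorImpl_cpu is a class template, defining operator==, zeros() and copyCast() in a .cpp file only links if that source file explicitly instantiates the template for every element type used from other translation units. The rest of the new source file is not shown in this excerpt; the following generic sketch (hypothetical BufferSketch class, not Aidge code) illustrates the header/source split and the explicit instantiations such a move relies on.

// buffer_sketch.hpp (hypothetical)
#pragma once
#include <cstddef>

template <typename T>
class BufferSketch {
    T*          mData;
    std::size_t mSize;
public:
    BufferSketch(T* data, std::size_t size) : mData(data), mSize(size) {}
    void zeros();   // declared here, defined in buffer_sketch.cpp
};

// buffer_sketch.cpp (hypothetical)
#include "buffer_sketch.hpp"

template <typename T>
void BufferSketch<T>::zeros() {
    // Same idea as TensorImpl_cpu<T>::zeros(): overwrite every element with T(0).
    for (std::size_t i = 0; i < mSize; ++i) {
        mData[i] = T(0);
    }
}

// Explicit instantiations: without these, a translation unit that only sees the
// header and calls zeros() on these element types would fail to link.
template class BufferSketch<float>;
template class BufferSketch<double>;
template class BufferSketch<int>;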