From 23d46d9ef0906d84724fb1ed0d034c90c44cafc2 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sun, 3 Dec 2023 18:57:06 +0100 Subject: [PATCH] Initial concept for Convert operator (UNTESTED) --- include/aidge/backend/cpu/data/TensorImpl.hpp | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 012ff5af..377b4546 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -5,6 +5,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { template <class T> @@ -37,16 +38,87 @@ class TensorImpl_cpu : public TensorImpl { std::size_t scalarSize() const override { return sizeof(T); } + void setDevice(int device) override { + AIDGE_ASSERT(device == 0, "device cannot be != 0 for CPU backend"); + } + void copy(const void *src, NbElts_t length) override { std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length, static_cast<T *>(rawPtr())); } + void copyCast(const void *src, NbElts_t length, const DataType srcDt) override { + if (srcDt == DataType::Float64) { + std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::Float32) { + std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::Int64) { + std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::UInt64) { + std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::Int32) { + std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::UInt32) { + std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::Int16) { + std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::UInt16) { + std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::Int8) { + std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else if (srcDt == DataType::UInt8) { + std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length, + static_cast<T *>(rawPtr())); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type."); + } + } + + void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) override { + AIDGE_ASSERT(device.first == Backend, "backend must match"); + AIDGE_ASSERT(device.second == 0, "device cannot be != 0 for CPU backend"); + copy(src, length); + } + + void copyFromHost(const void *src, NbElts_t length) override { + copy(src, length); + } + + void copyToHost(void *dst, NbElts_t length) override { + const T* src = static_cast<const T*>(rawPtr()); + std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length, + static_cast<T *>(dst)); + } + void *rawPtr() override { lazyInit(mData); return mData.data(); }; + void *hostPtr() override { + lazyInit(mData); + return mData.data(); + }; + void* getRaw(std::size_t idx){ return static_cast<void*>(static_cast<T *>(rawPtr()) + idx); }; -- GitLab