From fca8066f1b268cf445583f0836e5dd91c6bfff77 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 4 Jan 2024 11:43:02 +0100 Subject: [PATCH] Added offset to copy + renamed getRaw to getRawPtr --- include/aidge/backend/TensorImpl.hpp | 11 ++++++++--- include/aidge/data/Tensor.hpp | 6 +++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index 2345fb0d9..981bb9e43 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -45,8 +45,9 @@ public: * Copy data from the same device. * @param src Pointer on current implementation device. * @param length Number of elements to copy. + * @param offset Destination offset (in number of elements). */ - virtual void copy(const void *src, NbElts_t length) = 0; + virtual void copy(const void *src, NbElts_t length, NbElts_t offset = 0) = 0; /** * Copy-convert data from the same device. @@ -92,6 +93,11 @@ public: virtual void* hostPtr() { return nullptr; }; virtual const void* hostPtr() const { return nullptr; }; + /** + * Get the device pointer with an offset (in number of elements). + */ + virtual void* getRawPtr(NbElts_t idx) = 0; + /** * Sets the device pointer. The previously owned data is deleted. * UNSAFE: directly setting the device pointer may lead to undefined behavior @@ -104,8 +110,7 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend); }; - virtual void* getRaw(std::size_t /*idx*/)=0; - + virtual std::size_t size() const = 0; // Storage size virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) constexpr const char *backend() const { return mBackend; } virtual ~TensorImpl() = default; diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index a5680f927..3d518026f 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -367,14 +367,14 @@ class Tensor : public Data, expectedType& get(std::size_t idx){ // TODO : add assert expected Type compatible with datatype // TODO : add assert idx < Size - return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx)); } template <typename expectedType> const expectedType& get(std::size_t idx) const { // TODO : add assert expected Type compatible with datatype // TODO : add assert idx < Size - return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx)); } template <typename expectedType> @@ -391,7 +391,7 @@ class Tensor : public Data, void set(std::size_t idx, expectedType value){ // TODO : add assert expected Type compatible with datatype // TODO : add assert idx < Size - expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx)); + expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRawPtr(idx)); *dataPtr = value; } -- GitLab