Skip to content
Snippets Groups Projects
Commit fca8066f authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added offset to copy + renamed getRaw to getRawPtr

parent 2d9f6d80
No related branches found
No related tags found
No related merge requests found
...@@ -45,8 +45,9 @@ public: ...@@ -45,8 +45,9 @@ public:
* Copy data from the same device. * Copy data from the same device.
* @param src Pointer on current implementation device. * @param src Pointer on current implementation device.
* @param length Number of elements to copy. * @param length Number of elements to copy.
* @param offset Destination offset (in number of elements).
*/ */
virtual void copy(const void *src, NbElts_t length) = 0; virtual void copy(const void *src, NbElts_t length, NbElts_t offset = 0) = 0;
/** /**
* Copy-convert data from the same device. * Copy-convert data from the same device.
...@@ -92,6 +93,11 @@ public: ...@@ -92,6 +93,11 @@ public:
virtual void* hostPtr() { return nullptr; }; virtual void* hostPtr() { return nullptr; };
virtual const void* hostPtr() const { return nullptr; }; virtual const void* hostPtr() const { return nullptr; };
/**
* Get the device pointer with an offset (in number of elements).
*/
virtual void* getRawPtr(NbElts_t idx) = 0;
/** /**
* Sets the device pointer. The previously owned data is deleted. * Sets the device pointer. The previously owned data is deleted.
* UNSAFE: directly setting the device pointer may lead to undefined behavior * UNSAFE: directly setting the device pointer may lead to undefined behavior
...@@ -104,8 +110,7 @@ public: ...@@ -104,8 +110,7 @@ public:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend); AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
}; };
virtual void* getRaw(std::size_t /*idx*/)=0; virtual std::size_t size() const = 0; // Storage size
virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
constexpr const char *backend() const { return mBackend; } constexpr const char *backend() const { return mBackend; }
virtual ~TensorImpl() = default; virtual ~TensorImpl() = default;
......
...@@ -367,14 +367,14 @@ class Tensor : public Data, ...@@ -367,14 +367,14 @@ class Tensor : public Data,
expectedType& get(std::size_t idx){ expectedType& get(std::size_t idx){
// TODO : add assert expected Type compatible with datatype // TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size // TODO : add assert idx < Size
return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx));
} }
template <typename expectedType> template <typename expectedType>
const expectedType& get(std::size_t idx) const { const expectedType& get(std::size_t idx) const {
// TODO : add assert expected Type compatible with datatype // TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size // TODO : add assert idx < Size
return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx));
} }
template <typename expectedType> template <typename expectedType>
...@@ -391,7 +391,7 @@ class Tensor : public Data, ...@@ -391,7 +391,7 @@ class Tensor : public Data,
void set(std::size_t idx, expectedType value){ void set(std::size_t idx, expectedType value){
// TODO : add assert expected Type compatible with datatype // TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size // TODO : add assert idx < Size
expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx)); expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRawPtr(idx));
*dataPtr = value; *dataPtr = value;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment