Skip to content
Snippets Groups Projects
Commit fca8066f authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added offset to copy + renamed getRaw to getRawPtr

parent 2d9f6d80
No related branches found
No related tags found
No related merge requests found
......@@ -45,8 +45,9 @@ public:
* Copy data from the same device.
* @param src Pointer on current implementation device.
* @param length Number of elements to copy.
* @param offset Destination offset (in number of elements).
*/
virtual void copy(const void *src, NbElts_t length) = 0;
virtual void copy(const void *src, NbElts_t length, NbElts_t offset = 0) = 0;
/**
* Copy-convert data from the same device.
......@@ -92,6 +93,11 @@ public:
virtual void* hostPtr() { return nullptr; };
virtual const void* hostPtr() const { return nullptr; };
/**
* Get the device pointer with an offset (in number of elements).
*/
virtual void* getRawPtr(NbElts_t idx) = 0;
/**
* Sets the device pointer. The previously owned data is deleted.
* UNSAFE: directly setting the device pointer may lead to undefined behavior
......@@ -104,8 +110,7 @@ public:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
};
virtual void* getRaw(std::size_t /*idx*/)=0;
virtual std::size_t size() const = 0; // Storage size
virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
constexpr const char *backend() const { return mBackend; }
virtual ~TensorImpl() = default;
......
......@@ -367,14 +367,14 @@ class Tensor : public Data,
expectedType& get(std::size_t idx){
// TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size
return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx));
}
template <typename expectedType>
const expectedType& get(std::size_t idx) const {
// TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size
return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx));
}
template <typename expectedType>
......@@ -391,7 +391,7 @@ class Tensor : public Data,
void set(std::size_t idx, expectedType value){
// TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size
expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx));
expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRawPtr(idx));
*dataPtr = value;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment