diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index 2345fb0d9a09cf3910393dd9d23ce1d2b90db489..981bb9e43e61de1b07d5ff2e97cccea78a19ac56 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -45,8 +45,9 @@ public: * Copy data from the same device. * @param src Pointer on current implementation device. * @param length Number of elements to copy. + * @param offset Destination offset (in number of elements). */ - virtual void copy(const void *src, NbElts_t length) = 0; + virtual void copy(const void *src, NbElts_t length, NbElts_t offset = 0) = 0; /** * Copy-convert data from the same device. @@ -92,6 +93,11 @@ public: virtual void* hostPtr() { return nullptr; }; virtual const void* hostPtr() const { return nullptr; }; + /** + * Get the device pointer with an offset (in number of elements). + */ + virtual void* getRawPtr(NbElts_t idx) = 0; + /** * Sets the device pointer. The previously owned data is deleted. * UNSAFE: directly setting the device pointer may lead to undefined behavior @@ -104,8 +110,7 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend); }; - virtual void* getRaw(std::size_t /*idx*/)=0; - + virtual std::size_t size() const = 0; // Storage size virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) constexpr const char *backend() const { return mBackend; } virtual ~TensorImpl() = default; diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index a5680f92760a615ea807a3a137da3c49d3652be7..3d518026f3f5266b81da5aa9ab65efc02c39a902 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -367,14 +367,14 @@ class Tensor : public Data, expectedType& get(std::size_t idx){ // TODO : add assert expected Type compatible with datatype // TODO : add assert idx < Size - return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx)); } template <typename expectedType> const expectedType& get(std::size_t idx) const { // TODO : add assert expected Type compatible with datatype // TODO : add assert idx < Size - return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + return *reinterpret_cast<expectedType *>(mImpl->getRawPtr(idx)); } template <typename expectedType> @@ -391,7 +391,7 @@ class Tensor : public Data, void set(std::size_t idx, expectedType value){ // TODO : add assert expected Type compatible with datatype // TODO : add assert idx < Size - expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx)); + expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRawPtr(idx)); *dataPtr = value; }