Commit 1bd36647 authored by Olivier BICHLER

Merge branch 'view' into 'dev'

New proposal for handling tensor views

See merge request eclipse/aidge/aidge_core!69
parents 9f7c89aa deea0be0
# Version 0.1.0 (January 23, 2024)
Initial release
@@ -67,19 +67,13 @@ private:
class TensorImpl {
public:
    TensorImpl() = delete;
-   TensorImpl(const char *backend, DeviceIdx_t device = 0) : mBackend(backend), mDevice(device){};
+   TensorImpl(const char *backend, DeviceIdx_t device, NbElts_t length) : mBackend(backend), mDevice(device), mNbElts(length) {};

    /**
     * Return the (backend, device) pair for this implementation.
     */
    std::pair<std::string, DeviceIdx_t> device() const { return std::make_pair(mBackend, mDevice); }

-   /**
-    * Set the device ID for current backend.
-    * @param device New device ID on current backend.
-    */
-   virtual void setDevice(DeviceIdx_t device) = 0;

    /**
     * Copy data from the same device.
     * @param src Pointer on current implementation device.
@@ -93,30 +87,34 @@ public:
     * @param srcDt Source data type.
     * @param src Pointer on current implementation device.
     * @param length Number of elements to copy.
+    * @param offset Destination offset (in number of elements).
     */
-   virtual void copyCast(const void *src, NbElts_t length, const DataType srcDt) = 0;
+   virtual void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) = 0;

    /**
     * Copy data from an other device on the same backend.
     * @param device (backend, device) pair to copy from. The backend must match current implementation backend.
     * @param src Pointer on current implementation backend.
     * @param length Number of elements to copy.
+    * @param offset Destination offset (in number of elements).
     */
-   virtual void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) = 0;
+   virtual void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) = 0;

    /**
     * Copy data from host.
     * @param src Host pointer to copy from.
     * @param length Number of elements to copy.
+    * @param offset Destination offset (in number of elements).
     */
-   virtual void copyFromHost(const void *src, NbElts_t length) = 0;
+   virtual void copyFromHost(const void *src, NbElts_t length, NbElts_t offset = 0) = 0;

    /**
     * Copy data to host.
     * @param src Host pointer to copy to.
     * @param length Number of elements to copy.
+    * @param offset Source offset (in number of elements).
     */
-   virtual void copyToHost(void *dst, NbElts_t length) const = 0;
+   virtual void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const = 0;

    /**
     * Return the raw device pointer.
@@ -146,8 +144,22 @@ public:
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
    };

-   virtual std::size_t size() const = 0; // Storage size
-   virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
+   /**
+    * Set the size, in number of elements, that must be stored.
+    */
+   void resize(NbElts_t length) {
+       mNbElts = length;
+   }
+
+   /**
+    * Return the number of elements stored.
+    */
+   inline std::size_t size() const noexcept { return mNbElts; }
+
+   /**
+    * Return the size (in bytes) of one element (scalar).
+    */
+   virtual std::size_t scalarSize() const noexcept = 0;

    constexpr const char *backend() const { return mBackend; }
    virtual ~TensorImpl() = default;
    virtual bool operator==(const TensorImpl &othImpl) const = 0;
@@ -156,12 +168,16 @@ public:
     * Copy from another backend.
     * @param srcImpl Source TensorImpl to copy from.
     * @param length Number of elements of size scalarSize() to copy
+    * @param srcOffset Source offset (in number of elements).
+    * @param dstOffset Destination offset (in number of elements).
     */
-   void copyFrom(const TensorImpl& srcImpl, NbElts_t length);
+   void copyFrom(const TensorImpl& srcImpl, NbElts_t length, NbElts_t srcOffset = 0, NbElts_t dstOffset = 0);

protected:
    const char *mBackend;
-   DeviceIdx_t mDevice;
+   const DeviceIdx_t mDevice;
+   /// Number of elements (to be) stored
+   NbElts_t mNbElts;
};
} // namespace Aidge
...
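The new `offset` parameters on `copyCast()`, `copyFrom*()` and `copyToHost()` are what let a Tensor view address a sub-region of a storage shared with other Tensors. As a rough sketch of the intent (not part of this merge request; the class name `DummyHostImpl` and its `mData` member are made up), a host backend could honor the element-wise destination offset like this:

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical host-backend sketch (DummyHostImpl and mData are NOT part of this MR).
// Offsets and lengths are counted in elements, matching the new TensorImpl API.
template <class T>
class DummyHostImpl {
public:
    void copyFromHost(const void *src, std::size_t length, std::size_t offset = 0) {
        if (mData.size() < offset + length) {
            mData.resize(offset + length);
        }
        // Destination offset is in elements, so the pointer arithmetic is on T.
        std::memcpy(mData.data() + offset, src, length * sizeof(T));
    }

    const T* rawPtr(std::size_t offset = 0) const { return mData.data() + offset; }

private:
    std::vector<T> mData;  // contiguous host storage
};
```

Keeping offsets in elements (consistent with `length`) confines all byte arithmetic to the backend implementation.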
@@ -32,15 +32,18 @@ namespace Aidge {
 * Contains a pointer to an actual contiguous implementation of data.
 */
class Tensor : public Data,
-   public Registrable<Tensor, std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor &)> {
+   public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)> {
private:
    DataType mDataType; /** enum to specify data type. */
    std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */
-   std::unique_ptr<TensorImpl> mImpl; /** Pointer to the actual data implementation. */
+   std::vector<DimSize_t> mStrides; /** Stride dimensions of the tensor. */
+   std::shared_ptr<TensorImpl> mImpl; /** Pointer to the actual data implementation. */
+   std::size_t mImplOffset = 0;
    std::shared_ptr<Tensor> mGrad; /** Pointer to the associated gradient Tensor instance. */

    // Cached data
    std::size_t mSize = 0; /** Number of elements in the Tensor. */
+   bool mContiguous = true;

public:
    static constexpr const char *Type = "Tensor";
@@ -57,21 +60,29 @@ class Tensor : public Data,
    }

    /**
-    * @brief Construct a new Tensor object copied from another one.
+    * @brief Construct a new Tensor object from another one (shallow copy).
+    * Data memory is not copied, but shared between the new Tensor and the
+    * initial one.
+    *
     * @param otherTensor
     */
-   Tensor(const Tensor& otherTensor)
-       : Data(Type),
-         mDataType(otherTensor.mDataType),
-         mDims(otherTensor.mDims),
-         mSize(otherTensor.mSize)
-   {
-       if (otherTensor.hasImpl()) {
-           mImpl = Registrar<Tensor>::create({otherTensor.mImpl->backend(), dataType()})(*this);
-           mImpl->setDevice(otherTensor.mImpl->device().second);
-           // Same backend, same device => directly use copy()
-           mImpl->copy(otherTensor.mImpl->rawPtr(), mSize);
-       }
-   }
+   Tensor(const Tensor&) = default;
+   Tensor(Tensor&&) = default;
+
+   /**
+    * Perform a deep copy of the tensor.
+    */
+   Tensor clone() const {
+       Tensor newTensor(*this);
+       if (!newTensor.isContiguous()) {
+           newTensor.makeContiguous();
+       }
+       else {
+           std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mSize);
+           newImpl->copy(mImpl->rawPtr(mImplOffset), mSize);
+           newTensor.setImpl(newImpl);
+       }
+       return newTensor;
+   }
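With this change the copy constructor becomes a cheap shallow copy, and an explicit `clone()` is needed for a deep copy. A short usage sketch (hypothetical values, assuming a registered "cpu" backend, e.g. from the aidge_backend_cpu module):

```cpp
Aidge::Tensor a = Aidge::Array2D<float, 2, 2>{{{1.0f, 2.0f}, {3.0f, 4.0f}}};

Aidge::Tensor b(a);           // shallow copy: b shares a's TensorImpl and offset
Aidge::Tensor c = a.clone();  // deep copy: c gets its own, contiguous storage

b.set<float>(0, 42.0f);       // visible through a (shared storage), not through c
```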
    /**
@@ -84,7 +95,8 @@ class Tensor : public Data,
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0}),
-         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
+         mStrides({1}),
+         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0)),
          mSize(SIZE_0) {
        mImpl->copyFromHost(&arr.data[0], SIZE_0);
    }
@@ -93,9 +105,9 @@ class Tensor : public Data,
    constexpr Tensor &operator=(Array1D<T, SIZE_0> &&arr) {
        resize({SIZE_0});
        if (!mImpl) {
-           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
+           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0);
        }
-       mImpl->copyFromHost(&arr.data[0], SIZE_0);
+       mImpl->copyFromHost(&arr.data[0], SIZE_0, mImplOffset);
        return *this;
    }
@@ -110,7 +122,8 @@ class Tensor : public Data,
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1}),
-         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
+         mStrides({SIZE_1, 1}),
+         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1)),
          mSize(SIZE_0 * SIZE_1) {
        mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
    }
@@ -119,9 +132,9 @@ class Tensor : public Data,
    constexpr Tensor &operator=(Array2D<T, SIZE_0, SIZE_1> &&arr) {
        resize({SIZE_0, SIZE_1});
        if (!mImpl) {
-           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
+           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1);
        }
-       mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
+       mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1, mImplOffset);
        return *this;
    }
@@ -137,7 +150,8 @@ class Tensor : public Data,
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1, SIZE_2}),
-         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
+         mStrides({SIZE_1 * SIZE_2, SIZE_2, 1}),
+         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2)),
          mSize(SIZE_0 * SIZE_1 * SIZE_2) {
        mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
    }
@@ -146,9 +160,9 @@ class Tensor : public Data,
    constexpr Tensor &operator=(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr) {
        resize({SIZE_0, SIZE_1, SIZE_2});
        if (!mImpl) {
-           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
+           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2);
        }
-       mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
+       mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2, mImplOffset);
        return *this;
    }
@@ -165,7 +179,8 @@ class Tensor : public Data,
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}),
-         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
+         mStrides({SIZE_1 * SIZE_2 * SIZE_3, SIZE_2 * SIZE_3, SIZE_3, 1}),
+         mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3)),
          mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3) {
        mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
    }
@@ -174,33 +189,35 @@ class Tensor : public Data,
    constexpr Tensor &operator=(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr) {
        resize({SIZE_0, SIZE_1, SIZE_2, SIZE_3});
        if (!mImpl) {
-           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
+           mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
        }
-       mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
+       mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3, mImplOffset);
        return *this;
    }

    /**
-    * @brief Copy dimensions, datatype and data of another Tensor.
+    * @brief Copy dimensions, datatype and data from another Tensor.
+    * If current Tensor already has an implementation, data is copied to the
+    * existing implementation. Tensor backend/device remain untouched.
+    * If current Tensor does not have an implementation, only a shallow copy
+    * is performed and the Tensor will share data with t.
     * @param t other Tensor object.
     * @return Tensor&
     */
    Tensor &operator=(const Tensor &t) {
-       resize(t.dims());
-       setDataType(t.dataType());
+       resize(t.dims(), t.strides());
+       setDataType(t.dataType(), false); // do not convert existing data
        if (t.hasImpl()) {
            if (hasImpl()) {
-               copyCastFrom(t);
+               copyFrom(t);
            }
            else {
-               mImpl = Registrar<Tensor>::create({t.mImpl->backend(), dataType()})(*this);
-               mImpl->setDevice(t.mImpl->device().second);
-               // Same backend, same device => directly use copy()
-               mImpl->copy(t.mImpl->rawPtr(), mSize);
+               // Perform a shallow copy only
+               setImpl(t.mImpl, t.mImplOffset);
            }
        }
        else {
-           mImpl = nullptr;
+           setImpl(nullptr);
        }
        return *this;
    }
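Assignment now distinguishes two cases: if the destination already has an implementation, the data is copied into it (its backend/device are preserved); otherwise the assignment degrades to a shallow copy that shares the source storage. A hypothetical snippet (assuming the "cpu" backend is available):

```cpp
Aidge::Tensor src = Aidge::Array1D<float, 4>{{1.0f, 2.0f, 3.0f, 4.0f}};
Aidge::Tensor dst = Aidge::Array1D<float, 4>{{0.0f, 0.0f, 0.0f, 0.0f}};
Aidge::Tensor empty;  // no implementation yet

dst = src;    // deep copy into dst's existing impl; dst keeps its backend/device
empty = src;  // shallow copy: empty now shares src's impl and storage offset
```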
@@ -233,17 +250,15 @@ class Tensor : public Data,
        if (mImpl->device() != std::make_pair(name, device)) {
            // Backend change: create new impl, copy from old to new and replace
            // impl
-           std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this);
-           newImpl->setDevice(device);
+           std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(device, mImpl->size());
            if (copyFrom) {
-               newImpl->copyFrom(*mImpl, size());
+               newImpl->copyFrom(*mImpl, mImpl->size(), mImplOffset, 0);
            }
-           mImpl = std::move(newImpl);
+           setImpl(newImpl);
        }
    }
    else {
-       mImpl = Registrar<Tensor>::create({name, mDataType})(*this);
-       mImpl->setDevice(device);
+       mImpl = Registrar<Tensor>::create({name, mDataType})(device, mSize);
    }
}
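setBackend() now passes the device index and storage size directly to the registrar instead of calling the removed setDevice(). Call sites are unchanged; a small sketch (the "cuda" backend name is only an example and requires the corresponding plugin):

```cpp
Aidge::Tensor t = Aidge::Array1D<float, 3>{{1.0f, 2.0f, 3.0f}};
t.setBackend("cuda", 0);         // new impl on device 0, data copied over (copyFrom = true)
t.setBackend("cuda", 1, false);  // move to device 1, previous contents discarded
```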
@@ -273,21 +288,32 @@ class Tensor : public Data,
     */
    void setDataType(const DataType dt, bool copyCast = true) {
        if (mImpl && (dataType() != dt)) {
-           std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
+           std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(mImpl->device().second, mImpl->size());
            if (copyCast) {
-               newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
+               newImpl->copyCast(mImpl->rawPtr(mImplOffset), mDataType, mImpl->size());
            }
-           mImpl = std::move(newImpl);
+           setImpl(newImpl);
        }
        mDataType = dt;
    }

    /**
     * @brief Get the Impl object
-    * @return constexpr const std::unique_ptr<TensorImpl>&
+    * @return constexpr const std::shared_ptr<TensorImpl>&
     */
-   constexpr const std::unique_ptr<TensorImpl> &getImpl() { return mImpl; }
-   constexpr const std::unique_ptr<TensorImpl> &getImpl() const { return mImpl; }
+   constexpr const std::shared_ptr<TensorImpl> &getImpl() const { return mImpl; }
+   constexpr std::size_t getImplOffset() const { return mImplOffset; }
+
+   /**
+    * @brief Set the Impl object
+    *
+    * @param impl New impl shared pointer
+    * @param implOffset Storage offset in this new impl for this Tensor
+    */
+   void setImpl(std::shared_ptr<TensorImpl> impl, std::size_t implOffset = 0) {
+       mImpl = impl;
+       mImplOffset = implOffset;
+   }
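setImpl() and getImplOffset() expose the building blocks of a view: a Tensor simply points into a region of another Tensor's implementation. The extract() methods further down are built on exactly this mechanism. A hand-rolled sketch (assuming a "cpu" backend is registered):

```cpp
Aidge::Tensor base = Aidge::Array2D<float, 2, 3>{{{0.0f, 1.0f, 2.0f}, {3.0f, 4.0f, 5.0f}}};

Aidge::Tensor view(base.dataType());
view.resize({3});                                        // one row of 3 elements
view.setImpl(base.getImpl(), base.getImplOffset() + 3);  // start at row #1

// view.get<float>(0) == 3.0f; writes through view are visible in base
```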
    /**
     * @brief Return if an implementaiton has been associated.
@@ -319,6 +345,18 @@ class Tensor : public Data,
     */
    constexpr const std::vector<DimSize_t> &dims() const { return mDims; }

+   /**
+    * @brief Get strides of the Tensor object.
+    * @return constexpr const std::vector<DimSize_t>&
+    */
+   constexpr const std::vector<DimSize_t> &strides() const { return mStrides; }
+
+   /**
+    * @brief Return true if Tensor is contiguous in memory.
+    * @return bool
+    */
+   constexpr bool isContiguous() const { return mContiguous; }

    /**
     * @brief Get the number of elements in the Tensor object.
     * @return constexpr std::size_t
@@ -350,10 +388,49 @@ class Tensor : public Data,
     * one, all previous data is invalided. Otherwise, previous data may or may
     * not remain valid, depending on the backend implementation.
     * @param dims New dimensions
+    * @param strides Stride of the tensor (if not specified, "nested" stride is used)
     */
-   void resize(const std::vector<DimSize_t> &dims) {
-       mDims = dims;
-       computeSize();
+   void resize(const std::vector<DimSize_t> &dims, std::vector<DimSize_t> strides = std::vector<DimSize_t>()) {
+       bool checkContiguous = true;
+       if (strides.empty()) {
+           strides.resize(dims.size());
+           size_t expectedStride = 1;
+           for (int dim = dims.size() - 1; dim >= 0; --dim) {
+               strides[dim] = expectedStride;
+               expectedStride*= dims[dim];
+           }
+           checkContiguous = false;
+       }
+       else {
+           AIDGE_ASSERT(strides.size() == dims.size(), "Number of strides must match number of dims");
+       }
+
+       if (mImpl.use_count() > 1) {
+           // Here we could also create a new storage for this tensor in this case
+           // But, is it more likely that the user really wants this, or that he did a mistake?
+           AIDGE_ASSERT(dims == mDims && strides == mStrides, "Cannot resize Tensor with shared storage");
+       }
+       else {
+           mDims = dims;
+           mStrides = strides;
+
+           mContiguous = true;
+           if (checkContiguous) {
+               size_t expectedStride = 1;
+               for (int dim = dims.size() - 1; dim >= 0; --dim) {
+                   if (strides[dim] != expectedStride) {
+                       mContiguous = false;
+                       break;
+                   }
+                   expectedStride*= dims[dim];
+               }
+           }
+
+           computeSize();
+           if (mImpl) {
+               mImpl->resize(mSize);
+           }
+       }
    }
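When no strides are given, resize() computes the default "nested" (row-major) layout, where each dimension's stride is the product of all following dimensions; a tensor whose strides match this layout is flagged contiguous. A standalone illustration of that computation:

```cpp
#include <cstddef>
#include <vector>

// Default "nested" (row-major) strides, as computed by resize() when none are given.
std::vector<std::size_t> nestedStrides(const std::vector<std::size_t>& dims) {
    std::vector<std::size_t> strides(dims.size());
    std::size_t expectedStride = 1;
    for (int dim = static_cast<int>(dims.size()) - 1; dim >= 0; --dim) {
        strides[dim] = expectedStride;
        expectedStride *= dims[dim];
    }
    return strides;
}
// nestedStrides({2, 3, 4}) == {12, 4, 1}; a tensor resized this way is contiguous.
```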
    /**
@@ -367,25 +444,25 @@ class Tensor : public Data,
    const expectedType& get(std::size_t idx) const {
        AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type");
        AIDGE_ASSERT(idx < mSize, "idx out of range");
-       return *reinterpret_cast<expectedType *>(mImpl->hostPtr(idx));
+       return *reinterpret_cast<expectedType *>(mImpl->hostPtr(mImplOffset + idx));
    }

    template <typename expectedType>
    const expectedType& get(std::vector<std::size_t> coordIdx) const {
-       return get<expectedType>(getIdx(coordIdx));
+       return get<expectedType>(getStorageIdx(coordIdx));
    }

    template <typename expectedType>
    void set(std::size_t idx, expectedType value){
        AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type");
        AIDGE_ASSERT(idx < mSize, "idx out of range");
-       expectedType* dataPtr = static_cast<expectedType*>(mImpl->hostPtr(idx));
+       expectedType* dataPtr = static_cast<expectedType*>(mImpl->hostPtr(mImplOffset + idx));
        *dataPtr = value;
    }

    template <typename expectedType>
    void set(std::vector<std::size_t> coordIdx, expectedType value){
-       set<expectedType>(getIdx(coordIdx), value);
+       set<expectedType>(getStorageIdx(coordIdx), value);
    }
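Coordinate-based get()/set() now resolve elements through getStorageIdx(), i.e. through the strides, while the index-based overloads keep addressing raw storage (shifted by mImplOffset). A quick sketch, assuming a contiguous 2x3 "cpu" tensor:

```cpp
Aidge::Tensor t = Aidge::Array2D<float, 2, 3>{{{0.0f, 1.0f, 2.0f}, {3.0f, 4.0f, 5.0f}}};

t.set<float>({1, 2}, 42.0f);           // coordinates -> storage index via strides
float byCoord = t.get<float>({1, 2});  // 42.0f
float byIdx   = t.get<float>(5);       // also 42.0f here, but only because t is contiguous
```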
@@ -449,9 +526,9 @@ class Tensor : public Data,
    for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) {
        res += spaceString + "{";
        for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) {
-           res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + ",";
+           res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + ",";
        }
-       res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + "}";
+       res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + "}";
        if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) {
            res += ",";
        }
@@ -471,7 +548,7 @@ class Tensor : public Data,
    } else {
        res += "{";
        for (DimSize_t j = 0; j < dims()[0]; ++j) {
-           res += " " + ptrToString(mDataType, mImpl->hostPtr(), j) + ((j < dims()[0]-1) ? "," : " ");
+           res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) + ((j < dims()[0]-1) ? "," : " ");
        }
    }
    res += "}";
@@ -493,6 +570,7 @@ class Tensor : public Data,
    /**
     * @brief From the the 1D contiguous index, return the coordinate of an element in the tensor.
+    * Beware: do not use this function with the storage index!
     *
     * @param flatIdx 1D contiguous index of the value considering a flatten, contiguous, tensor.
     * @return std::vector<DimSize_t>
@@ -512,6 +590,8 @@ class Tensor : public Data,
     * @brief From the coordinate returns the 1D contiguous index of an element in the tensor.
     * If the number of coordinates is inferior to the number of dimensions,
     * the remaining coordinates are assumed to be 0.
+    * Beware: the contiguous index will only correspond to the storage index
+    * if the tensor is contiguous!
     *
     * @param coordIdx Coordinate to an element in the tensor
     * @return DimSize_t Contiguous index
@@ -527,6 +607,51 @@ class Tensor : public Data,
        return flatIdx + coordIdx[i];
    }

+   /**
+    * @brief From the coordinate returns the 1D storage index of an element in the tensor.
+    * If the number of coordinates is inferior to the number of dimensions,
+    * the remaining coordinates are assumed to be 0.
+    *
+    * @param coordIdx Coordinate to an element in the tensor
+    * @return DimSize_t Storage index
+    */
+   std::size_t getStorageIdx(const std::vector<std::size_t>& coordIdx) const {
+       AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
+       return std::inner_product(coordIdx.begin(), coordIdx.end(), mStrides.begin(), DimSize_t(0));
+   }
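The storage index is simply the dot product of the coordinates with the strides (the Tensor then adds its own mImplOffset when addressing the impl). Standalone equivalent for reference:

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Same computation as Tensor::getStorageIdx(): dot product of coordinates and strides.
std::size_t storageIdx(const std::vector<std::size_t>& coord,
                       const std::vector<std::size_t>& strides) {
    return std::inner_product(coord.begin(), coord.end(), strides.begin(), std::size_t(0));
}
// With strides {12, 4, 1} (contiguous 2x3x4 tensor), coord {1, 2, 3} maps to 1*12 + 2*4 + 3 = 23.
```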
+   /**
+    * Returns a sub-tensor with one or more dimension less.
+    * For instance, t.extract({1}) on a CHW tensor will return the HW tensor
+    * of channel #1.
+    * Likewise, t.extract({0, 1}) on a NCHW tensor will return the HW tensor
+    * of batch #0 and channel #1.
+    * No memory copy is performed, the returned tensor does not own the memory.
+    * If the number of coordinates matches the number of dimensions, an empty
+    * tensor is returned.
+    * It current tensor was contiguous, the returned tensor is garanteed to be
+    * contiguous as well.
+    *
+    * @param coordIdx Coordinates of the sub-tensor to extract
+    * @return Tensor Sub-tensor.
+    */
+   Tensor extract(const std::vector<std::size_t>& coordIdx) const;
+
+   /**
+    * Returns a sub-tensor at some coordinate and with some dimension.
+    *
+    * @param coordIdx First coordinates of the sub-tensor to extract
+    * @param dims Dimensions of the sub-tensor to extract
+    * @return Tensor Sub-tensor.
+    */
+   Tensor extract(const std::vector<std::size_t>& coordIdx, const std::vector<std::size_t>& dims) const;
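Both extract() overloads return views: they share the parent's storage and only differ in the dims, strides and offset they carry. A usage sketch (hypothetical values, assuming a "cpu" backend):

```cpp
// 2x2x3 CHW-like tensor
Aidge::Tensor chw = Aidge::Array3D<float, 2, 2, 3>{{
    {{0.0f, 1.0f, 2.0f}, {3.0f, 4.0f, 5.0f}},
    {{6.0f, 7.0f, 8.0f}, {9.0f, 10.0f, 11.0f}}
}};

Aidge::Tensor hw  = chw.extract({1});                   // 2x3 view of channel #1, no copy
Aidge::Tensor win = chw.extract({0, 1, 0}, {1, 1, 2});  // 1x1x2 window starting at (0, 1, 0)
// win reads {3.0f, 4.0f} but keeps chw's strides, so it is generally not contiguous.
```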
+   /**
+    * Make the tensor's storage contiguous, if it is not already the case.
+    * If not contiguous, a new memory space is allocated.
+    */
+   void makeContiguous();

    /**
     * Copy-cast data from a Tensor on the same device.
     * If current tensor backend/device is set and is different from src, an
@@ -572,6 +697,20 @@ class Tensor : public Data,
        copyCastFrom(src, movedSrc);
    }
+   /**
+    * Return a reference to a Tensor that is garanteed to be contiguous:
+    * - itself, if already contiguous;
+    * - the provided Tensor, overwritten with the copied data.
+    * The data type, backend and device stay the same.
+    * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+    * The shared_ptr does not need to be initialized. No new memory allocation
+    * will occur if fallback has already been allocated with the right
+    * type/size/device.
+    * @return Reference to either itself or to fallback.
+    */
+   Tensor& refContiguous(std::shared_ptr<Tensor>& fallback);
+   const Tensor& refContiguous(std::shared_ptr<Tensor>& fallback) const;
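refContiguous() follows the same fallback pattern as the existing refCast()/refFrom() helpers: the caller provides a reusable scratch Tensor and gets back either the original or the scratch, whichever is contiguous. Typical call site (someView and fallback are placeholder names, not from the MR):

```cpp
std::shared_ptr<Aidge::Tensor> fallback;
const Aidge::Tensor& contiguousInput = someView.refContiguous(fallback);
// contiguousInput is someView itself if it was already contiguous,
// otherwise it is *fallback, filled with a packed copy of the data.
```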
    /**
     * Return a reference to a Tensor casted to the desired data type:
     * - itself, if already at the right data type;
@@ -642,6 +781,43 @@ class Tensor : public Data,
        return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
    }

+   /**
+    * Return a reference to a Tensor on desired data type and backend/device:
+    * - itself, if already with the right characteristics;
+    * - the provided Tensor, overwritten with the right characteristics.
+    * NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on
+    * the same fallback, it remains valid, otherwise, data is invalid.
+    * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+    * The shared_ptr does not need to be initialized. No new memory allocation
+    * will occur if fallback has already been allocated with the right
+    * type/size/device.
+    * @param dt The desired data type.
+    * @param backend The desired backend.
+    * @param device The desired device.
+    * @return Reference to either itself or to fallback.
+    */
+   Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0);
+   const Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) const;
+
+   /**
+    * Return a reference to a Tensor with same characteristics
+    * (data type, backend/device) as targetReqs Tensor:
+    * - itself, if already with the right characteristics;
+    * - the provided Tensor, overwritten with the right characteristics.
+    * NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on
+    * the same fallback, it remains valid, otherwise, data is invalid.
+    * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+    * The shared_ptr does not need to be initialized. No new memory allocation
+    * will occur if fallback has already been allocated with the right
+    * type/size/device.
+    * @param targetReqs Tensor with the desired target characteristics.
+    * @return Reference to either itself or to fallback.
+    */
+   Tensor& ref(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) {
+       const auto& device = targetReqs.getImpl()->device();
+       return ref(fallback, targetReqs.dataType(), device.first, device.second);
+   }
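Unlike refCastFrom()/refFrom(), the new ref() overloads only guarantee the characteristics (data type, backend/device, dims) of the returned Tensor; they never copy data, which makes them suitable for preparing outputs. A sketch (input and targetReqs are placeholder names):

```cpp
std::shared_ptr<Aidge::Tensor> fallback;
Aidge::Tensor& out = input.ref(fallback, targetReqs);
// out matches targetReqs' data type and backend/device (and input's dims),
// but its contents are unspecified until something is written into it.
```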
private:
    ///\bug not protected against overflow
    void computeSize() {
...
@@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor,
    Data,
    Registrable<Tensor,
                std::tuple<std::string, DataType>,
-               std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
+               std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){
    mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
        /* Request a buffer descriptor from Python */
        py::buffer_info info = b.request();
@@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor,
void init_Tensor(py::module& m){
    py::class_<Registrable<Tensor,
                           std::tuple<std::string, DataType>,
-                          std::unique_ptr<TensorImpl>(const Tensor&)>,
+                          std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>,
               std::shared_ptr<Registrable<Tensor,
                                           std::tuple<std::string, DataType>,
-                                          std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable");
+                                          std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable");

    py::class_<Tensor, std::shared_ptr<Tensor>,
               Data,
               Registrable<Tensor,
                           std::tuple<std::string, DataType>,
-                          std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor
+                          std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor
        (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());

    pyClassTensor.def(py::init<>())
@@ -76,7 +76,7 @@ void init_Tensor(py::module& m){
    .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
    .def("dtype", &Tensor::dataType)
    .def("size", &Tensor::size)
-   .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize)
+   .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
    .def("has_impl", &Tensor::hasImpl)
    .def("get_coord", &Tensor::getCoord)
    .def("get_idx", &Tensor::getIdx)
@@ -118,7 +118,7 @@ void init_Tensor(py::module& m){
        }
    })
    .def_buffer([](Tensor& b) -> py::buffer_info {
-       const std::unique_ptr<TensorImpl>& tensorImpl = b.getImpl();
+       const std::shared_ptr<TensorImpl>& tensorImpl = b.getImpl();
        std::vector<size_t> dims;
        std::vector<size_t> strides;
...
@@ -14,23 +14,23 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"

-void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
-   if (&srcImpl == this) {
+void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length, NbElts_t srcOffset, NbElts_t dstOffset) {
+   if (&srcImpl == this && srcOffset == dstOffset) {
        return;
    }

    if (srcImpl.device() != device()) {
        if (srcImpl.backend() == backend()) {
            // Same backend, but different device
-           copyFromDevice(srcImpl.rawPtr(), length, srcImpl.device());
+           copyFromDevice(srcImpl.rawPtr(srcOffset), srcImpl.device(), length, dstOffset);
        }
        else if (srcImpl.hostPtr() != nullptr) {
            // Different backend, but input is valid on host
-           copyFromHost(srcImpl.hostPtr(), length);
+           copyFromHost(srcImpl.hostPtr(srcOffset), length, dstOffset);
        }
        else if (hostPtr() != nullptr) {
            // Different backend, but dst is valid on host
-           srcImpl.copyToHost(hostPtr(), length);
+           srcImpl.copyToHost(hostPtr(srcOffset), length, dstOffset);
        }
        else {
            // No direct link possible from src to dst device
@@ -40,12 +40,12 @@ void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
            // - There is currently no concrete use case
            // - Just providing a pointer would be unsafe (risk of buffer overflow...)
            auto tmpHostBuffer = std::unique_ptr<char[]>(new char[scalarSize() * length]);
-           srcImpl.copyToHost(tmpHostBuffer.get(), length);
-           copyFromHost(tmpHostBuffer.get(), length);
+           srcImpl.copyToHost(tmpHostBuffer.get(), length, srcOffset);
+           copyFromHost(tmpHostBuffer.get(), length, dstOffset);
        }
    }
    else {
        // Same device: simple copy on device
-       copy(srcImpl.rawPtr(), length);
+       copy(srcImpl.rawPtr(srcOffset), length, dstOffset);
    }
}
@@ -13,11 +13,72 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"

+Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& coordIdx) const {
+   AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
+   AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Number of coordinates is higher than number of dimensions");
+
+   Tensor subTensor(mDataType);
+   subTensor.resize(std::vector<size_t>(mDims.begin() + coordIdx.size(), mDims.end()),
+       std::vector<size_t>(mStrides.begin() + coordIdx.size(), mStrides.end()));
+   subTensor.setBackend(mImpl->backend(), mImpl->device().second);
+   subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(coordIdx));
+   return subTensor;
+}
+
+Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& coordIdx, const std::vector<std::size_t>& dims) const {
+   AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
+   AIDGE_ASSERT(coordIdx.size() == mDims.size(), "Coordinates does not match number of dimensions");
+
+   Tensor subTensor(mDataType);
+   subTensor.resize(dims, mStrides);
+   subTensor.setBackend(mImpl->backend(), mImpl->device().second);
+   subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(coordIdx));
+   return subTensor;
+}
+void Aidge::Tensor::makeContiguous() {
+   if (!mImpl || isContiguous()) {
+       return;
+   }
+
+   // Block so that mImpl ref count is 1 for resize()
+   {
+       // Create a new storage that will be contiguous
+       std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mSize);
+       // Copy elements from old to new storage
+       size_t idx = 0;
+       while (idx < mSize) {
+           const size_t storageIdx = getStorageIdx(getCoord(idx));
+
+           // Determine the size of the contiguous chunk
+           size_t copySize = 1;
+           while (idx + copySize < mSize &&
+               getStorageIdx(getCoord(idx + copySize)) == storageIdx + copySize)
+           {
+               ++copySize;
+           }
+
+           // Perform a single copy for the contiguous chunk
+           newImpl->copy(mImpl->rawPtr(mImplOffset + storageIdx), copySize, idx);
+
+           // Move to the next index after the contiguous chunk
+           idx += copySize;
+       }
+
+       // Replace old storage by new, contiguous, storage
+       setImpl(newImpl);
+   }
+
+   // Resize tensor without strides => tensor is now contiguous
+   resize(mDims);
+}
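makeContiguous() walks the logical element order and coalesces runs that are already consecutive in storage into single copy() calls, so a strided view pays one copy per contiguous chunk rather than one per element. Effect in practice (hypothetical values, "cpu" backend assumed):

```cpp
Aidge::Tensor base = Aidge::Array2D<float, 2, 4>{{{0.0f, 1.0f, 2.0f, 3.0f},
                                                  {4.0f, 5.0f, 6.0f, 7.0f}}};
Aidge::Tensor view = base.extract({0, 1}, {2, 2});  // 2x2 window starting at column 1
// view reads {1, 2, 5, 6}, but its two rows are 4 elements apart in base's storage.

view.makeContiguous();  // copies two 2-element chunks into a new 4-element storage
// view no longer aliases base and is now contiguous.
```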
void Aidge::Tensor::copyCast(const Tensor& src) {
    if (&src == this) {
        return;
    }

+   AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast non-contiguous tensor");
+
    // Current Tensor has necessarily a data type, but may not have backend
    if (!getImpl()) {
        // If no backend was set for the current tensor, use the same as src
@@ -27,7 +88,7 @@ void Aidge::Tensor::copyCast(const Tensor& src) {
    resize(src.dims());
    AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device");
-   getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType());
+   getImpl()->copyCast(src.getImpl()->rawPtr(src.mImplOffset), src.dataType(), src.size(), mImplOffset);
}

void Aidge::Tensor::copyFrom(const Tensor& src) {
@@ -35,6 +96,8 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
        return;
    }

+   AIDGE_ASSERT(src.isContiguous(), "cannot copy from non-contiguous tensor");
+
    // Current Tensor has necessarily a data type, but may not have backend
    if (!getImpl()) {
        // If no backend was set for the current tensor, use the same as src
@@ -44,7 +107,7 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
    resize(src.dims());
    AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type");
-   getImpl()->copyFrom(*(src.getImpl()), src.size());
+   getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
}
void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
@@ -52,6 +115,8 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
        return;
    }

+   AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast from non-contiguous tensor");
+
    // Current Tensor has necessarily a data type, but may not have backend
    if (!getImpl()) {
        // If no backend was set for the current tensor, use the same as src
@@ -65,12 +130,35 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
        const auto device = getImpl()->device();
        const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second);
        // Second, copy-cast data (necessary)
-       getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType());
+       getImpl()->copyCast(movedSrc.getImpl()->rawPtr(movedSrc.mImplOffset), movedSrc.dataType(), movedSrc.size(), mImplOffset);
    }
    else {
        // Directly copy, no conversion necessary
        // Avoid making a double copy if both data type and device are the same
-       getImpl()->copyFrom(*(src.getImpl()), src.size());
+       getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
    }
}

+Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) {
+   // Scott Meyers' solution to avoid code duplication
+   return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refContiguous(fallback));
+}
+
+const Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) const {
+   AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
+
+   if (isContiguous()) {
+       return *this;
+   }
+   else {
+       if (this != fallback.get()) {
+           // Shallow copy to fallback
+           *fallback = *this;
+       }
+
+       // Make fallback contiguous
+       fallback->makeContiguous();
+       return *fallback;
+   }
+}
@@ -91,6 +179,8 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
        fallback->setDataType(dt);
    }
    else {
+       AIDGE_ASSERT(isContiguous(), "cannot refCast non-contiguous tensor");
+
        if (!fallback) {
            fallback = std::make_shared<Tensor>(dt);
        }
@@ -101,7 +191,7 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
        const auto device = getImpl()->device();
        fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
        fallback->resize(dims());
-       fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
+       fallback->getImpl()->copyCast(getImpl()->rawPtr(mImplOffset), dataType(), size(), fallback->mImplOffset);
    }
    return *fallback;
}
@@ -124,6 +214,8 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
        fallback->setBackend(backend, device);
    }
    else {
+       AIDGE_ASSERT(isContiguous(), "cannot refFrom non-contiguous tensor");
+
        if (!fallback) {
            fallback = std::make_shared<Tensor>(dataType());
        }
@@ -133,8 +225,34 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
        fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
        fallback->resize(dims());
-       fallback->getImpl()->copyFrom(*getImpl(), size());
+       fallback->getImpl()->copyFrom(*getImpl(), size(), mImplOffset, fallback->mImplOffset);
        }
        return *fallback;
    }
}

+Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) {
+   // Scott Meyers' solution to avoid code duplication
+   return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
+}
+
+const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) const {
+   AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
+
+   if (dt == dataType() && std::make_pair(backend, device) == getImpl()->device()) {
+       return *this;
+   }
+   else {
+       // Change fallback type, backend & device, without any data copy
+       if (!fallback) {
+           fallback = std::make_shared<Tensor>(dt);
+       }
+       else {
+           fallback->setDataType(dt, false); // don't keep previous data (no copy)
+       }
+
+       fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
+       fallback->resize(dims());
+       return *fallback;
+   }
+}