Commit 9cb56340 authored by Maxence Naud

Merge remote-tracking branch 'origin/dev' into dataloader

Parents: f18032c9 1bd36647
Merge requests: !105 "version 0.2.0", !4 "Dataloader"
@@ -67,19 +67,13 @@ private:
class TensorImpl {
public:
TensorImpl() = delete;
- TensorImpl(const char *backend, DeviceIdx_t device = 0) : mBackend(backend), mDevice(device){};
TensorImpl(const char *backend, DeviceIdx_t device, NbElts_t length) : mBackend(backend), mDevice(device), mNbElts(length) {};
/**
* Return the (backend, device) pair for this implementation.
*/
std::pair<std::string, DeviceIdx_t> device() const { return std::make_pair(mBackend, mDevice); }
- /**
- * Set the device ID for current backend.
- * @param device New device ID on current backend.
- */
- virtual void setDevice(DeviceIdx_t device) = 0;
/**
* Copy data from the same device.
* @param src Pointer on current implementation device.
@@ -93,30 +87,34 @@ public:
* @param srcDt Source data type.
* @param src Pointer on current implementation device.
* @param length Number of elements to copy.
* @param offset Destination offset (in number of elements).
*/
- virtual void copyCast(const void *src, NbElts_t length, const DataType srcDt) = 0;
virtual void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) = 0;
/**
* Copy data from an other device on the same backend.
* @param device (backend, device) pair to copy from. The backend must match current implementation backend.
* @param src Pointer on current implementation backend.
* @param length Number of elements to copy.
* @param offset Destination offset (in number of elements).
*/
- virtual void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) = 0;
virtual void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) = 0;
/**
* Copy data from host.
* @param src Host pointer to copy from.
* @param length Number of elements to copy.
* @param offset Destination offset (in number of elements).
*/
- virtual void copyFromHost(const void *src, NbElts_t length) = 0;
virtual void copyFromHost(const void *src, NbElts_t length, NbElts_t offset = 0) = 0;
/**
* Copy data to host.
* @param src Host pointer to copy to.
* @param length Number of elements to copy.
* @param offset Source offset (in number of elements).
*/
- virtual void copyToHost(void *dst, NbElts_t length) const = 0;
virtual void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const = 0;
/**
* Return the raw device pointer.
@@ -146,8 +144,22 @@ public:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
};
- virtual std::size_t size() const = 0; // Storage size
- virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
/**
* Set the size, in number of elements, that must be stored.
*/
void resize(NbElts_t length) {
mNbElts = length;
}
/**
* Return the number of elements stored.
*/
inline std::size_t size() const noexcept { return mNbElts; }
/**
* Return the size (in bytes) of one element (scalar).
*/
virtual std::size_t scalarSize() const noexcept = 0;
constexpr const char *backend() const { return mBackend; }
virtual ~TensorImpl() = default;
virtual bool operator==(const TensorImpl &othImpl) const = 0;
@@ -156,12 +168,16 @@ public:
* Copy from another backend.
* @param srcImpl Source TensorImpl to copy from.
* @param length Number of elements of size scalarSize() to copy
* @param srcOffset Source offset (in number of elements).
* @param dstOffset Destination offset (in number of elements).
*/
- void copyFrom(const TensorImpl& srcImpl, NbElts_t length);
void copyFrom(const TensorImpl& srcImpl, NbElts_t length, NbElts_t srcOffset = 0, NbElts_t dstOffset = 0);
protected:
const char *mBackend;
- DeviceIdx_t mDevice;
const DeviceIdx_t mDevice;
/// Number of elements (to be) stored
NbElts_t mNbElts;
};
} // namespace Aidge
...
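For readers skimming the new TensorImpl interface, the sketch below shows how a concrete backend could honour the element-offset parameters introduced above. It is a minimal standalone illustration with made-up names (ToyCpuImpl, a float-only buffer), not Aidge's actual cpu backend.

// Standalone sketch, assumption: simplified stand-ins replace the Aidge headers.
#include <cstddef>
#include <cstring>
#include <vector>

using NbElts_t = std::size_t;  // simplified stand-in for the Aidge typedef

class ToyCpuImpl {
public:
    explicit ToyCpuImpl(NbElts_t length) : mData(length) {}

    void resize(NbElts_t length) { mData.resize(length); }
    std::size_t size() const noexcept { return mData.size(); }
    std::size_t scalarSize() const noexcept { return sizeof(float); }

    // Write `length` elements starting at element offset `offset` (mirrors copyFromHost above).
    void copyFromHost(const void* src, NbElts_t length, NbElts_t offset = 0) {
        std::memcpy(mData.data() + offset, src, length * scalarSize());
    }

    // Read `length` elements starting at element offset `offset` (mirrors copyToHost above).
    void copyToHost(void* dst, NbElts_t length, NbElts_t offset = 0) const {
        std::memcpy(dst, mData.data() + offset, length * scalarSize());
    }

private:
    std::vector<float> mData;  // float-only storage, enough for the illustration
};

int main() {
    const float host[4] = {1.f, 2.f, 3.f, 4.f};
    ToyCpuImpl impl(8);
    impl.copyFromHost(host, 4, /*offset=*/2);  // data lands in elements [2, 6)
    float back[4] = {};
    impl.copyToHost(back, 4, /*offset=*/2);    // back == {1, 2, 3, 4}
    return 0;
}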
@@ -32,15 +32,18 @@ namespace Aidge {
* Contains a pointer to an actual contiguous implementation of data.
*/
class Tensor : public Data,
- public Registrable<Tensor, std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor &)> {
public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)> {
private:
DataType mDataType; /** enum to specify data type. */
std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */
std::vector<DimSize_t> mStrides; /** Stride dimensions of the tensor. */
- std::unique_ptr<TensorImpl> mImpl; /** Pointer to the actual data implementation. */
std::shared_ptr<TensorImpl> mImpl; /** Pointer to the actual data implementation. */
std::size_t mImplOffset = 0;
std::shared_ptr<Tensor> mGrad; /** Pointer to the associated gradient Tensor instance. */
// Cached data
std::size_t mSize = 0; /** Number of elements in the Tensor. */
bool mContiguous = true;
public:
static constexpr const char *Type = "Tensor";
@@ -71,21 +74,29 @@ class Tensor : public Data,
}
/**
- * @brief Construct a new Tensor object copied from another one.
* @brief Construct a new Tensor object from another one (shallow copy).
* Data memory is not copied, but shared between the new Tensor and the
* initial one.
*
* @param otherTensor
*/
- Tensor(const Tensor& otherTensor)
- : Data(Type),
- mDataType(otherTensor.mDataType),
- mDims(otherTensor.mDims),
- mSize(otherTensor.mSize)
- {
- if (otherTensor.hasImpl()) {
- mImpl = Registrar<Tensor>::create({otherTensor.mImpl->backend(), dataType()})(*this);
- mImpl->setDevice(otherTensor.mImpl->device().second);
- // Same backend, same device => directly use copy()
- mImpl->copy(otherTensor.mImpl->rawPtr(), mSize);
- }
- }
Tensor(const Tensor&) = default;
Tensor(Tensor&&) = default;
/**
* Perform a deep copy of the tensor.
*/
Tensor clone() const {
Tensor newTensor(*this);
if (!newTensor.isContiguous()) {
newTensor.makeContiguous();
}
else {
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mSize);
newImpl->copy(mImpl->rawPtr(mImplOffset), mSize);
newTensor.setImpl(newImpl);
}
return newTensor;
}
/**
@@ -98,7 +109,8 @@ class Tensor : public Data,
: Data(Type),
mDataType(NativeType<T>::type),
mDims({SIZE_0}),
mStrides({1}),
- mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0)),
mSize(SIZE_0) {
mImpl->copyFromHost(&arr.data[0], SIZE_0);
}
@@ -107,9 +119,9 @@ class Tensor : public Data,
constexpr Tensor &operator=(Array1D<T, SIZE_0> &&arr) {
resize({SIZE_0});
if (!mImpl) {
- mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0);
}
- mImpl->copyFromHost(&arr.data[0], SIZE_0);
mImpl->copyFromHost(&arr.data[0], SIZE_0, mImplOffset);
return *this;
}
@@ -124,7 +136,8 @@ class Tensor : public Data,
: Data(Type),
mDataType(NativeType<T>::type),
mDims({SIZE_0, SIZE_1}),
mStrides({SIZE_1, 1}),
- mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1)),
mSize(SIZE_0 * SIZE_1) {
mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
}
@@ -133,9 +146,9 @@ class Tensor : public Data,
constexpr Tensor &operator=(Array2D<T, SIZE_0, SIZE_1> &&arr) {
resize({SIZE_0, SIZE_1});
if (!mImpl) {
- mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1);
}
- mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1, mImplOffset);
return *this;
}
@@ -151,7 +164,8 @@ class Tensor : public Data,
: Data(Type),
mDataType(NativeType<T>::type),
mDims({SIZE_0, SIZE_1, SIZE_2}),
mStrides({SIZE_1 * SIZE_2, SIZE_2, 1}),
- mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2)),
mSize(SIZE_0 * SIZE_1 * SIZE_2) {
mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
}
@@ -160,9 +174,9 @@ class Tensor : public Data,
constexpr Tensor &operator=(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr) {
resize({SIZE_0, SIZE_1, SIZE_2});
if (!mImpl) {
- mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2);
}
- mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2, mImplOffset);
return *this;
}
@@ -179,7 +193,8 @@ class Tensor : public Data,
: Data(Type),
mDataType(NativeType<T>::type),
mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}),
mStrides({SIZE_1 * SIZE_2 * SIZE_3, SIZE_2 * SIZE_3, SIZE_3, 1}),
- mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3)),
mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3) {
mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
}
@@ -188,33 +203,35 @@ class Tensor : public Data,
constexpr Tensor &operator=(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr) {
resize({SIZE_0, SIZE_1, SIZE_2, SIZE_3});
if (!mImpl) {
- mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
}
- mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3, mImplOffset);
return *this;
}
/**
- * @brief Copy dimensions, datatype and data of another Tensor.
* @brief Copy dimensions, datatype and data from another Tensor.
* If current Tensor already has an implementation, data is copied to the
* existing implementation. Tensor backend/device remain untouched.
* If current Tensor does not have an implementation, only a shallow copy
* is performed and the Tensor will share data with t.
* @param t other Tensor object.
* @return Tensor&
*/
Tensor &operator=(const Tensor &t) {
- resize(t.dims());
resize(t.dims(), t.strides());
- setDataType(t.dataType());
setDataType(t.dataType(), false); // do not convert existing data
if (t.hasImpl()) {
if (hasImpl()) {
- copyCastFrom(t);
copyFrom(t);
}
else {
- mImpl = Registrar<Tensor>::create({t.mImpl->backend(), dataType()})(*this);
- mImpl->setDevice(t.mImpl->device().second);
- // Same backend, same device => directly use copy()
- mImpl->copy(t.mImpl->rawPtr(), mSize);
// Perform a shallow copy only
setImpl(t.mImpl, t.mImplOffset);
}
}
else {
- mImpl = nullptr;
setImpl(nullptr);
}
return *this;
}
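A hypothetical usage sketch of the new assignment semantics, assuming the Aidge headers above and a registered "cpu" backend; variable names and values are illustrative only.

Tensor a = Array1D<int, 3>{{1, 2, 3}};

Tensor b;              // no implementation yet
b = a;                 // shallow copy: b now shares a's storage (same impl and offset)
b.set<int>(0, 42);     // the change is visible through a as well

Tensor c = a.clone();  // deep copy: c gets its own storage
c.set<int>(0, 7);      // a and b are unaffected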
@@ -247,17 +264,15 @@ class Tensor : public Data,
if (mImpl->device() != std::make_pair(name, device)) {
// Backend change: create new impl, copy from old to new and replace
// impl
- std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this);
- newImpl->setDevice(device);
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(device, mImpl->size());
if (copyFrom) {
- newImpl->copyFrom(*mImpl, size());
newImpl->copyFrom(*mImpl, mImpl->size(), mImplOffset, 0);
}
- mImpl = std::move(newImpl);
setImpl(newImpl);
}
}
else {
- mImpl = Registrar<Tensor>::create({name, mDataType})(*this);
- mImpl->setDevice(device);
mImpl = Registrar<Tensor>::create({name, mDataType})(device, mSize);
}
}
@@ -287,21 +302,32 @@ class Tensor : public Data,
*/
void setDataType(const DataType dt, bool copyCast = true) {
if (mImpl && (dataType() != dt)) {
- std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(mImpl->device().second, mImpl->size());
if (copyCast) {
- newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
newImpl->copyCast(mImpl->rawPtr(mImplOffset), mDataType, mImpl->size());
}
- mImpl = std::move(newImpl);
setImpl(newImpl);
}
mDataType = dt;
}
/**
* @brief Get the Impl object
- * @return constexpr const std::unique_ptr<TensorImpl>&
* @return constexpr const std::shared_ptr<TensorImpl>&
*/
- constexpr const std::unique_ptr<TensorImpl> &getImpl() { return mImpl; }
- constexpr const std::unique_ptr<TensorImpl> &getImpl() const { return mImpl; }
constexpr const std::shared_ptr<TensorImpl> &getImpl() const { return mImpl; }
constexpr std::size_t getImplOffset() const { return mImplOffset; }
/**
* @brief Set the Impl object
*
* @param impl New impl shared pointer
* @param implOffset Storage offset in this new impl for this Tensor
*/
void setImpl(std::shared_ptr<TensorImpl> impl, std::size_t implOffset = 0) {
mImpl = impl;
mImplOffset = implOffset;
}
/**
* @brief Return if an implementaiton has been associated.
@@ -333,6 +359,18 @@ class Tensor : public Data,
*/
constexpr const std::vector<DimSize_t> &dims() const { return mDims; }
/**
* @brief Get strides of the Tensor object.
* @return constexpr const std::vector<DimSize_t>&
*/
constexpr const std::vector<DimSize_t> &strides() const { return mStrides; }
/**
* @brief Return true if Tensor is contiguous in memory.
* @return bool
*/
constexpr bool isContiguous() const { return mContiguous; }
/**
* @brief Get the number of elements in the Tensor object.
* @return constexpr std::size_t
@@ -364,10 +402,49 @@ class Tensor : public Data,
* one, all previous data is invalided. Otherwise, previous data may or may
* not remain valid, depending on the backend implementation.
* @param dims New dimensions
* @param strides Stride of the tensor (if not specified, "nested" stride is used)
*/
- void resize(const std::vector<DimSize_t> &dims) {
- mDims = dims;
- computeSize();
void resize(const std::vector<DimSize_t> &dims, std::vector<DimSize_t> strides = std::vector<DimSize_t>()) {
bool checkContiguous = true;
if (strides.empty()) {
strides.resize(dims.size());
size_t expectedStride = 1;
for (int dim = dims.size() - 1; dim >= 0; --dim) {
strides[dim] = expectedStride;
expectedStride*= dims[dim];
}
checkContiguous = false;
}
else {
AIDGE_ASSERT(strides.size() == dims.size(), "Number of strides must match number of dims");
}
if (mImpl.use_count() > 1) {
// Here we could also create a new storage for this tensor in this case
// But, is it more likely that the user really wants this, or that he did a mistake?
AIDGE_ASSERT(dims == mDims && strides == mStrides, "Cannot resize Tensor with shared storage");
}
else {
mDims = dims;
mStrides = strides;
mContiguous = true;
if (checkContiguous) {
size_t expectedStride = 1;
for (int dim = dims.size() - 1; dim >= 0; --dim) {
if (strides[dim] != expectedStride) {
mContiguous = false;
break;
}
expectedStride*= dims[dim];
}
}
computeSize();
if (mImpl) {
mImpl->resize(mSize);
}
}
}
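To make the default-stride ("nested") computation and the contiguity check above easier to follow, here is a small standalone sketch in plain C++ (independent of the Aidge classes) that mirrors the same loops for a fixed shape.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    const std::vector<std::size_t> dims = {2, 3, 4};

    // "Nested" (row-major) strides, computed exactly like resize() does when none are given.
    std::vector<std::size_t> strides(dims.size());
    std::size_t expectedStride = 1;
    for (int dim = static_cast<int>(dims.size()) - 1; dim >= 0; --dim) {
        strides[dim] = expectedStride;
        expectedStride *= dims[dim];
    }
    // dims {2, 3, 4} -> strides {12, 4, 1}

    // A tensor is flagged contiguous when every stride matches the nested one.
    bool contiguous = true;
    expectedStride = 1;
    for (int dim = static_cast<int>(dims.size()) - 1; dim >= 0; --dim) {
        if (strides[dim] != expectedStride) { contiguous = false; break; }
        expectedStride *= dims[dim];
    }
    std::cout << "contiguous: " << std::boolalpha << contiguous << '\n';  // prints true
    return 0;
}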
/**
@@ -381,25 +458,25 @@ class Tensor : public Data,
const expectedType& get(std::size_t idx) const {
AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type");
AIDGE_ASSERT(idx < mSize, "idx out of range");
- return *reinterpret_cast<expectedType *>(mImpl->hostPtr(idx));
return *reinterpret_cast<expectedType *>(mImpl->hostPtr(mImplOffset + idx));
}
template <typename expectedType>
const expectedType& get(std::vector<std::size_t> coordIdx) const {
- return get<expectedType>(getIdx(coordIdx));
return get<expectedType>(getStorageIdx(coordIdx));
}
template <typename expectedType>
void set(std::size_t idx, expectedType value){
AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type");
AIDGE_ASSERT(idx < mSize, "idx out of range");
- expectedType* dataPtr = static_cast<expectedType*>(mImpl->hostPtr(idx));
expectedType* dataPtr = static_cast<expectedType*>(mImpl->hostPtr(mImplOffset + idx));
*dataPtr = value;
}
template <typename expectedType>
void set(std::vector<std::size_t> coordIdx, expectedType value){
- set<expectedType>(getIdx(coordIdx), value);
set<expectedType>(getStorageIdx(coordIdx), value);
}
@@ -463,9 +540,9 @@ class Tensor : public Data,
for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) {
res += spaceString + "{";
for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) {
- res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + ",";
res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + ",";
}
- res += " " + ptrToString(mDataType, mImpl->hostPtr(), counter++) + "}";
res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + "}";
if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) {
res += ",";
}
@@ -485,7 +562,7 @@ class Tensor : public Data,
} else {
res += "{";
for (DimSize_t j = 0; j < dims()[0]; ++j) {
- res += " " + ptrToString(mDataType, mImpl->hostPtr(), j) + ((j < dims()[0]-1) ? "," : " ");
res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) + ((j < dims()[0]-1) ? "," : " ");
}
}
res += "}";
@@ -507,6 +584,7 @@ class Tensor : public Data,
/**
* @brief From the the 1D contiguous index, return the coordinate of an element in the tensor.
* Beware: do not use this function with the storage index!
*
* @param flatIdx 1D contiguous index of the value considering a flatten, contiguous, tensor.
* @return std::vector<DimSize_t>
@@ -526,6 +604,8 @@ class Tensor : public Data,
* @brief From the coordinate returns the 1D contiguous index of an element in the tensor.
* If the number of coordinates is inferior to the number of dimensions,
* the remaining coordinates are assumed to be 0.
* Beware: the contiguous index will only correspond to the storage index
* if the tensor is contiguous!
*
* @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t Contiguous index
@@ -541,6 +621,51 @@ class Tensor : public Data,
return flatIdx + coordIdx[i];
}
/**
* @brief From the coordinate returns the 1D storage index of an element in the tensor.
* If the number of coordinates is inferior to the number of dimensions,
* the remaining coordinates are assumed to be 0.
*
* @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t Storage index
*/
std::size_t getStorageIdx(const std::vector<std::size_t>& coordIdx) const {
AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
return std::inner_product(coordIdx.begin(), coordIdx.end(), mStrides.begin(), DimSize_t(0));
}
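The storage index above is just an inner product of coordinates and strides; a tiny standalone example in plain C++ (layout and values assumed for illustration only):

#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    // A 2x3 view stored with a padded row stride of 4 (non-contiguous layout, assumed for the example).
    const std::vector<std::size_t> strides = {4, 1};
    const std::vector<std::size_t> coord   = {1, 2};
    const std::size_t storageIdx =
        std::inner_product(coord.begin(), coord.end(), strides.begin(), std::size_t(0));
    std::cout << storageIdx << '\n';  // prints 6 (= 1*4 + 2*1)
    return 0;
}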
/**
* Returns a sub-tensor with one or more dimension less.
* For instance, t.extract({1}) on a CHW tensor will return the HW tensor
* of channel #1.
* Likewise, t.extract({0, 1}) on a NCHW tensor will return the HW tensor
* of batch #0 and channel #1.
* No memory copy is performed, the returned tensor does not own the memory.
* If the number of coordinates matches the number of dimensions, an empty
* tensor is returned.
* It current tensor was contiguous, the returned tensor is garanteed to be
* contiguous as well.
*
* @param coordIdx Coordinates of the sub-tensor to extract
* @return Tensor Sub-tensor.
*/
Tensor extract(const std::vector<std::size_t>& coordIdx) const;
/**
* Returns a sub-tensor at some coordinate and with some dimension.
*
* @param coordIdx First coordinates of the sub-tensor to extract
* @param dims Dimensions of the sub-tensor to extract
* @return Tensor Sub-tensor.
*/
Tensor extract(const std::vector<std::size_t>& coordIdx, const std::vector<std::size_t>& dims) const;
/**
* Make the tensor's storage contiguous, if it is not already the case.
* If not contiguous, a new memory space is allocated.
*/
void makeContiguous();
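A hypothetical usage sketch for extract(), assuming the Aidge headers above and a registered "cpu" backend; the 2x3 shape and the values are made up for the example.

Tensor t = Array2D<int, 2, 3>{{{1, 2, 3}, {4, 5, 6}}};

// Sub-tensor on row #1: shares storage with t, no data is copied.
Tensor row = t.extract({1});
// row.dims() == {3} and row.get<int>(0) == 4

// 1x2 window starting at coordinate {0, 1}: still a view on t's storage.
Tensor window = t.extract({0, 1}, {1, 2});
// The window keeps t's strides, so it may be flagged non-contiguous;
// window.makeContiguous() would repack it into its own storage.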
/**
* Copy-cast data from a Tensor on the same device.
* If current tensor backend/device is set and is different from src, an
@@ -586,6 +711,20 @@ class Tensor : public Data,
copyCastFrom(src, movedSrc);
}
/**
* Return a reference to a Tensor that is garanteed to be contiguous:
* - itself, if already contiguous;
* - the provided Tensor, overwritten with the copied data.
* The data type, backend and device stay the same.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* type/size/device.
* @return Reference to either itself or to fallback.
*/
Tensor& refContiguous(std::shared_ptr<Tensor>& fallback);
const Tensor& refContiguous(std::shared_ptr<Tensor>& fallback) const;
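A hypothetical sketch of the fallback idiom for refContiguous(), assuming the declarations above; someView stands for any possibly non-contiguous Tensor (for instance one produced by extract()).

std::shared_ptr<Tensor> fallback;  // scratch tensor, allocated only if needed
const Tensor& contiguousView = someView.refContiguous(fallback);
// If someView was already contiguous, contiguousView aliases it and fallback stays empty;
// otherwise fallback now holds a contiguous copy and contiguousView refers to *fallback.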
/**
* Return a reference to a Tensor casted to the desired data type:
* - itself, if already at the right data type;
@@ -656,6 +795,43 @@ class Tensor : public Data,
return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
}
/**
* Return a reference to a Tensor on desired data type and backend/device:
* - itself, if already with the right characteristics;
* - the provided Tensor, overwritten with the right characteristics.
* NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on
* the same fallback, it remains valid, otherwise, data is invalid.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* type/size/device.
* @param dt The desired data type.
* @param backend The desired backend.
* @param device The desired device.
* @return Reference to either itself or to fallback.
*/
Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0);
const Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) const;
/**
* Return a reference to a Tensor with same characteristics
* (data type, backend/device) as targetReqs Tensor:
* - itself, if already with the right characteristics;
* - the provided Tensor, overwritten with the right characteristics.
* NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on
* the same fallback, it remains valid, otherwise, data is invalid.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
* type/size/device.
* @param targetReqs Tensor with the desired target characteristics.
* @return Reference to either itself or to fallback.
*/
Tensor& ref(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) {
const auto& device = targetReqs.getImpl()->device();
return ref(fallback, targetReqs.dataType(), device.first, device.second);
}
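A hypothetical sketch of ref() with target requirements, assuming the declarations above; unlike refCastFrom(), no data is copy-casted, so this is mainly useful to prepare a destination tensor. The names below are illustrative only.

std::shared_ptr<Tensor> outFallback;
// Get a tensor with input's dimensions but targetReqs' data type, backend and device,
// without paying for a data copy (its contents are left unspecified).
Tensor& out = input.ref(outFallback, targetReqs);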
private:
///\bug not protected against overflow
void computeSize() {
...
@@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
- std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){
mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
/* Request a buffer descriptor from Python */
py::buffer_info info = b.request();
@@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor,
void init_Tensor(py::module& m){
py::class_<Registrable<Tensor,
std::tuple<std::string, DataType>,
- std::unique_ptr<TensorImpl>(const Tensor&)>,
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>,
std::shared_ptr<Registrable<Tensor,
std::tuple<std::string, DataType>,
- std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable");
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable");
py::class_<Tensor, std::shared_ptr<Tensor>,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
- std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>())
@@ -76,7 +76,7 @@ void init_Tensor(py::module& m){
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
.def("dtype", &Tensor::dataType)
.def("size", &Tensor::size)
- .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
.def("has_impl", &Tensor::hasImpl)
.def("get_coord", &Tensor::getCoord)
.def("get_idx", &Tensor::getIdx)
@@ -118,7 +118,7 @@ void init_Tensor(py::module& m){
}
})
.def_buffer([](Tensor& b) -> py::buffer_info {
- const std::unique_ptr<TensorImpl>& tensorImpl = b.getImpl();
const std::shared_ptr<TensorImpl>& tensorImpl = b.getImpl();
std::vector<size_t> dims;
std::vector<size_t> strides;
...
@@ -14,23 +14,23 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
- void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length, NbElts_t srcOffset, NbElts_t dstOffset) {
- if (&srcImpl == this) {
if (&srcImpl == this && srcOffset == dstOffset) {
return;
}
if (srcImpl.device() != device()) {
if (srcImpl.backend() == backend()) {
// Same backend, but different device
- copyFromDevice(srcImpl.rawPtr(), length, srcImpl.device());
copyFromDevice(srcImpl.rawPtr(srcOffset), srcImpl.device(), length, dstOffset);
}
else if (srcImpl.hostPtr() != nullptr) {
// Different backend, but input is valid on host
- copyFromHost(srcImpl.hostPtr(), length);
copyFromHost(srcImpl.hostPtr(srcOffset), length, dstOffset);
}
else if (hostPtr() != nullptr) {
// Different backend, but dst is valid on host
- srcImpl.copyToHost(hostPtr(), length);
srcImpl.copyToHost(hostPtr(srcOffset), length, dstOffset);
}
else {
// No direct link possible from src to dst device
@@ -40,12 +40,12 @@ void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
// - There is currently no concrete use case
// - Just providing a pointer would be unsafe (risk of buffer overflow...)
auto tmpHostBuffer = std::unique_ptr<char[]>(new char[scalarSize() * length]);
- srcImpl.copyToHost(tmpHostBuffer.get(), length);
srcImpl.copyToHost(tmpHostBuffer.get(), length, srcOffset);
- copyFromHost(tmpHostBuffer.get(), length);
copyFromHost(tmpHostBuffer.get(), length, dstOffset);
}
}
else {
// Same device: simple copy on device
- copy(srcImpl.rawPtr(), length);
copy(srcImpl.rawPtr(srcOffset), length, dstOffset);
}
}
@@ -13,11 +13,72 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& coordIdx) const {
AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Number of coordinates is higher than number of dimensions");
Tensor subTensor(mDataType);
subTensor.resize(std::vector<size_t>(mDims.begin() + coordIdx.size(), mDims.end()),
std::vector<size_t>(mStrides.begin() + coordIdx.size(), mStrides.end()));
subTensor.setBackend(mImpl->backend(), mImpl->device().second);
subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(coordIdx));
return subTensor;
}
Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& coordIdx, const std::vector<std::size_t>& dims) const {
AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
AIDGE_ASSERT(coordIdx.size() == mDims.size(), "Coordinates does not match number of dimensions");
Tensor subTensor(mDataType);
subTensor.resize(dims, mStrides);
subTensor.setBackend(mImpl->backend(), mImpl->device().second);
subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(coordIdx));
return subTensor;
}
void Aidge::Tensor::makeContiguous() {
if (!mImpl || isContiguous()) {
return;
}
// Block so that mImpl ref count is 1 for resize()
{
// Create a new storage that will be contiguous
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mSize);
// Copy elements from old to new storage
size_t idx = 0;
while (idx < mSize) {
const size_t storageIdx = getStorageIdx(getCoord(idx));
// Determine the size of the contiguous chunk
size_t copySize = 1;
while (idx + copySize < mSize &&
getStorageIdx(getCoord(idx + copySize)) == storageIdx + copySize)
{
++copySize;
}
// Perform a single copy for the contiguous chunk
newImpl->copy(mImpl->rawPtr(mImplOffset + storageIdx), copySize, idx);
// Move to the next index after the contiguous chunk
idx += copySize;
}
// Replace old storage by new, contiguous, storage
setImpl(newImpl);
}
// Resize tensor without strides => tensor is now contiguous
resize(mDims);
}
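The chunk-coalescing loop in makeContiguous() can be read in isolation: it grows a run while consecutive flat indices map to consecutive storage indices, then copies the whole run in one call. The standalone sketch below (plain C++, independent of Aidge, with an assumed 2x2 view over a 2x3 storage) reproduces that idea.

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
    // Source storage of a 2x3 tensor; we view a 2x2 slice of it with strides {3, 1}.
    const std::vector<int> storage = {1, 2, 3, 4, 5, 6};
    const std::vector<std::size_t> dims = {2, 2};
    const std::vector<std::size_t> strides = {3, 1};

    // Coordinate of a flat (contiguous) index, like getCoord().
    auto coordOf = [&](std::size_t flatIdx) {
        std::vector<std::size_t> coord(dims.size());
        for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
            coord[d] = flatIdx % dims[d];
            flatIdx /= dims[d];
        }
        return coord;
    };
    // Storage index of a flat index, like getStorageIdx(getCoord(idx)).
    auto storageIdxOf = [&](std::size_t flatIdx) {
        const auto coord = coordOf(flatIdx);
        return std::inner_product(coord.begin(), coord.end(), strides.begin(), std::size_t(0));
    };

    const std::size_t size = dims[0] * dims[1];
    std::vector<int> contiguous(size);
    std::size_t idx = 0;
    while (idx < size) {
        const std::size_t storageIdx = storageIdxOf(idx);
        // Grow the chunk while the strided layout stays consecutive in storage.
        std::size_t copySize = 1;
        while (idx + copySize < size && storageIdxOf(idx + copySize) == storageIdx + copySize) {
            ++copySize;
        }
        // Copy the contiguous chunk in a single call.
        std::copy_n(storage.begin() + storageIdx, copySize, contiguous.begin() + idx);
        idx += copySize;
    }
    // contiguous == {1, 2, 4, 5}
    return 0;
}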
void Aidge::Tensor::copyCast(const Tensor& src) {
if (&src == this) {
return;
}
AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast non-contiguous tensor");
// Current Tensor has necessarily a data type, but may not have backend
if (!getImpl()) {
// If no backend was set for the current tensor, use the same as src
@@ -27,7 +88,7 @@ void Aidge::Tensor::copyCast(const Tensor& src) {
resize(src.dims());
AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device");
- getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType());
getImpl()->copyCast(src.getImpl()->rawPtr(src.mImplOffset), src.dataType(), src.size(), mImplOffset);
}
void Aidge::Tensor::copyFrom(const Tensor& src) {
@@ -35,6 +96,8 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
return;
}
AIDGE_ASSERT(src.isContiguous(), "cannot copy from non-contiguous tensor");
// Current Tensor has necessarily a data type, but may not have backend
if (!getImpl()) {
// If no backend was set for the current tensor, use the same as src
@@ -44,7 +107,7 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
resize(src.dims());
AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type");
- getImpl()->copyFrom(*(src.getImpl()), src.size());
getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
}
void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
@@ -52,6 +115,8 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
return;
}
AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast from non-contiguous tensor");
// Current Tensor has necessarily a data type, but may not have backend
if (!getImpl()) {
// If no backend was set for the current tensor, use the same as src
@@ -65,12 +130,35 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
const auto device = getImpl()->device();
const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second);
// Second, copy-cast data (necessary)
- getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType());
getImpl()->copyCast(movedSrc.getImpl()->rawPtr(movedSrc.mImplOffset), movedSrc.dataType(), movedSrc.size(), mImplOffset);
}
else {
// Directly copy, no conversion necessary
// Avoid making a double copy if both data type and device are the same
- getImpl()->copyFrom(*(src.getImpl()), src.size());
getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
}
}
Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) {
// Scott Meyers' solution to avoid code duplication
return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refContiguous(fallback));
}
const Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
if (isContiguous()) {
return *this;
}
else {
if (this != fallback.get()) {
// Shallow copy to fallback
*fallback = *this;
}
// Make fallback contiguous
fallback->makeContiguous();
return *fallback;
} }
}
@@ -91,6 +179,8 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
fallback->setDataType(dt);
}
else {
AIDGE_ASSERT(isContiguous(), "cannot refCast non-contiguous tensor");
if (!fallback) {
fallback = std::make_shared<Tensor>(dt);
}
@@ -101,7 +191,7 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
const auto device = getImpl()->device();
fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
fallback->resize(dims());
- fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
fallback->getImpl()->copyCast(getImpl()->rawPtr(mImplOffset), dataType(), size(), fallback->mImplOffset);
}
return *fallback;
}
@@ -124,6 +214,8 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
fallback->setBackend(backend, device);
}
else {
AIDGE_ASSERT(isContiguous(), "cannot refFrom non-contiguous tensor");
if (!fallback) {
fallback = std::make_shared<Tensor>(dataType());
}
@@ -133,8 +225,34 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
fallback->resize(dims());
- fallback->getImpl()->copyFrom(*getImpl(), size());
fallback->getImpl()->copyFrom(*getImpl(), size(), mImplOffset, fallback->mImplOffset);
}
return *fallback;
}
}
Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) {
// Scott Meyers' solution to avoid code duplication
return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
}
const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
if (dt == dataType() && std::make_pair(backend, device) == getImpl()->device()) {
return *this;
}
else {
// Change fallback type, backend & device, without any data copy
if (!fallback) {
fallback = std::make_shared<Tensor>(dt);
}
else {
fallback->setDataType(dt, false); // don't keep previous data (no copy)
}
fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
fallback->resize(dims());
return *fallback;
}
}