Skip to content
Snippets Groups Projects
Commit 77d70664 authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

Improve copy/move/clone behaviour consistency

- Add move assignment operator
- Fix 'Tensor::clone()' if original Tensor has no implementation
- Change behaviour to always perform a shallow copy when the copy constructor or copy-assignment operator is called
parent 7e4e784d
No related branches found
No related tags found
2 merge requests!279v0.4.0,!273[Fix] Producer clone and Tensor copy
...@@ -212,14 +212,13 @@ class Tensor : public Data, ...@@ -212,14 +212,13 @@ class Tensor : public Data,
/** /**
* @brief Copy dimensions, datatype and data from another Tensor. * @brief Copy dimensions, datatype and data from another Tensor.
* If current Tensor already has an implementation, data is copied to the * Tensor backend/device are also copied and only a shallow copy
* existing implementation. Tensor backend/device remain untouched. * is performed for data. Implementation will be shared with original Tensor.
* If current Tensor does not have an implementation, only a shallow copy
* is performed and the Tensor will share data with t.
* @param other other Tensor object. * @param other other Tensor object.
* @return Tensor& * @return Tensor&
*/ */
Tensor &operator=(const Tensor& other); Tensor &operator=(const Tensor& other) = default;
Tensor &operator=(Tensor&& other) = default;
template <typename T> template <typename T>
constexpr Tensor &operator=(Vector<T> &&arr) { constexpr Tensor &operator=(Vector<T> &&arr) {
...@@ -332,14 +331,17 @@ public: ...@@ -332,14 +331,17 @@ public:
* @brief Perform a deep copy of the tensor. * @brief Perform a deep copy of the tensor.
*/ */
Tensor clone() const { Tensor clone() const {
Tensor newTensor(*this); Tensor newTensor(*this); // shallow copy
if (!newTensor.isContiguous()) { // handle deepcopy of implementation if any
newTensor.makeContiguous(); if (newTensor.hasImpl()) {
} if (!newTensor.isContiguous()) {
else { newTensor.makeContiguous();
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mDims); }
newImpl->copy(mImpl->rawPtr(mImplOffset), mSize); else {
newTensor.setImpl(newImpl); std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mDims);
newImpl->copy(mImpl->rawPtr(mImplOffset), mSize);
newTensor.setImpl(newImpl);
}
} }
return newTensor; return newTensor;
} }
......
...@@ -135,24 +135,24 @@ Tensor Tensor::mean() const { ...@@ -135,24 +135,24 @@ Tensor Tensor::mean() const {
return mean_.getOutput(0)->clone(); return mean_.getOutput(0)->clone();
} }
Tensor& Tensor::operator=(const Tensor& other) { // Tensor& Tensor::operator=(const Tensor& other) {
if (this == &other) { // if (this == &other) {
return *this; // return *this;
} // }
resize(other.dims(), other.strides()); // resize(other.dims(), other.strides());
setDataType(other.dataType(), false); // do not convert existing data // setDataType(other.dataType(), false); // do not convert existing data
if (other.hasImpl()) { // if (other.hasImpl()) {
if (hasImpl()) { // if (hasImpl()) {
copyFrom(other); // // copyFrom(other);
} else { // // } else {
// Perform a shallow copy only // // Perform a shallow copy only
setImpl(other.mImpl, other.mImplOffset); // setImpl(other.mImpl, other.mImplOffset);
} // }
} else { // } else {
setImpl(nullptr); // setImpl(nullptr);
} // }
return *this; // return *this;
} // }
void Tensor::setBackend(const std::string &name, DeviceIdx_t device, bool copyFrom) { void Tensor::setBackend(const std::string &name, DeviceIdx_t device, bool copyFrom) {
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.