Commit 15e50d6b authored by Cyril Moineau

Fix toString method.

parent 41a9154c
2 merge requests: !119 0.2.1, !109 Add template_docstring decorator which allow to template Python docstring.
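The only functional change in this commit is the new `if (nbDims() != 2)` guard in toString(): the trailing closing-brace loop used to run even for 2-D tensors, whose outer brace is already closed by the printing branch, so 2-D tensors picked up spurious `}` lines. Every other hunk is a clang-format style reflow. A minimal usage sketch of the fixed printing path follows; the Array2D constructor, the include paths, and the availability of a registered "cpu" TensorImpl (normally provided by aidge_backend_cpu) are assumptions about the build, not part of this commit:

```cpp
#include <iostream>

#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ArrayHelpers.hpp"

int main() {
    // Before this commit, printing a 2-D tensor appended extra closing
    // braces; after it, the output closes cleanly.
    Aidge::Tensor t(Aidge::Array2D<float, 2, 2>{{{1.0f, 2.0f}, {3.0f, 4.0f}}});

    // Prints a nested-brace rendering of the {2, 2} tensor, with values
    // formatted by std::to_string, e.g.:
    // {
    //   { 1.000000, 2.000000},
    //   { 3.000000, 4.000000}
    // }
    std::cout << t.toString() << std::endl;
    return 0;
}
```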
@@ -23,29 +23,26 @@ Aidge::Tensor& Aidge::Tensor::operator=(const Aidge::Tensor& other) {
         return *this;
     }
     resize(other.dims(), other.strides());
     setDataType(other.dataType(), false); // do not convert existing data
     if (other.hasImpl()) {
         if (hasImpl()) {
             copyFrom(other);
-        }
-        else {
+        } else {
             // Perform a shallow copy only
             setImpl(other.mImpl, other.mImplOffset);
         }
-    }
-    else {
+    } else {
         setImpl(nullptr);
     }
     return *this;
 }
 
 Aidge::Tensor::~Tensor() noexcept = default;
 
-void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vector<Aidge::DimSize_t> strides) {
+void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t>& dims,
+                           std::vector<Aidge::DimSize_t> strides) {
     // TODO: scalar Tensor not handled
     if (dims.empty()) { // scalar
         mDims = std::vector<DimSize_t>(0);
         mStrides = std::vector<DimSize_t>({1});
         mContiguous = true;
@@ -63,20 +60,21 @@ void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vecto
         size_t expectedStride = 1;
         for (int dim = dims.size() - 1; dim >= 0; --dim) {
             strides[dim] = expectedStride;
-            expectedStride*= dims[dim];
+            expectedStride *= dims[dim];
         }
         checkContiguous = false;
-    }
-    else {
-        AIDGE_ASSERT(strides.size() == dims.size(), "Number of strides must match number of dims");
+    } else {
+        AIDGE_ASSERT(strides.size() == dims.size(),
+                     "Number of strides must match number of dims");
     }
 
     if (mImpl && mImpl.use_count() > 1) {
         // Here we could also create a new storage for this tensor in this case
-        // But, is it more likely that the user really wants this, or that he did a mistake?
-        AIDGE_ASSERT(dims == mDims && strides == mStrides, "Cannot resize Tensor with shared storage");
-    }
-    else {
+        // But, is it more likely that the user really wants this, or that he
+        // did a mistake?
+        AIDGE_ASSERT(dims == mDims && strides == mStrides,
+                     "Cannot resize Tensor with shared storage");
+    } else {
         mDims = dims;
         mStrides = strides;
@@ -88,12 +86,12 @@ void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vecto
             //     mContiguous&= (strides[i] == expectedStride);
             //     expectedStride*= dims[i];
             // }
-            for (std::size_t i = dims.size()-1; i > 0; --i) {
+            for (std::size_t i = dims.size() - 1; i > 0; --i) {
                 if (strides[i] != expectedStride) {
                     mContiguous = false;
                     break;
                 }
-                expectedStride*= dims[i];
+                expectedStride *= dims[i];
             }
             mContiguous &= (strides[0] == expectedStride);
         }
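As context for the stride handling in the hunks above, here is a standalone sketch (plain C++, deliberately not the Aidge API) of the two rules resize() applies: when no strides are given, it builds row-major ("C-contiguous") strides, and a stride vector is contiguous exactly when each stride equals the product of the dimensions to its right:

```cpp
#include <cstddef>
#include <vector>

// Default strides: innermost dimension has stride 1, each outer dimension's
// stride is the product of all dimensions to its right.
std::vector<std::size_t> rowMajorStrides(const std::vector<std::size_t>& dims) {
    std::vector<std::size_t> strides(dims.size());
    std::size_t expected = 1;
    for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
        strides[d] = expected;
        expected *= dims[d];
    }
    return strides;
}

// Contiguity check: same walk, but compare instead of assign.
bool isContiguous(const std::vector<std::size_t>& dims,
                  const std::vector<std::size_t>& strides) {
    std::size_t expected = 1;
    for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
        if (strides[d] != expected) {
            return false;
        }
        expected *= dims[d];
    }
    return true;
}

// Example: dims {2, 3, 4} -> strides {12, 4, 1}, which is contiguous.
```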
@@ -106,53 +104,59 @@ void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vecto
 }
 
 std::string Aidge::Tensor::toString() const {
-    AIDGE_ASSERT(mImpl && (dims().empty() || (dims() == std::vector<DimSize_t>({0})) || (mImpl->hostPtr() != nullptr)), "tensor should have a valid host pointer");
+    AIDGE_ASSERT(
+        mImpl && (dims().empty() || (dims() == std::vector<DimSize_t>({0})) ||
+                  (mImpl->hostPtr() != nullptr)),
+        "tensor should have a valid host pointer");
 
     // TODO: move lambda elsewhere?
     auto ptrToString = [](DataType dt, void* ptr, std::size_t idx) {
         switch (dt) {
         case DataType::Float64:
             return std::to_string(static_cast<double*>(ptr)[idx]);
         case DataType::Float32:
             return std::to_string(static_cast<float*>(ptr)[idx]);
         case DataType::Float16:
             return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
         case DataType::Int8:
             return std::to_string(static_cast<int8_t*>(ptr)[idx]);
         case DataType::Int16:
             return std::to_string(static_cast<int16_t*>(ptr)[idx]);
         case DataType::Int32:
             return std::to_string(static_cast<int32_t*>(ptr)[idx]);
         case DataType::Int64:
             return std::to_string(static_cast<int64_t*>(ptr)[idx]);
         case DataType::UInt8:
             return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
         case DataType::UInt16:
             return std::to_string(static_cast<uint16_t*>(ptr)[idx]);
         case DataType::UInt32:
             return std::to_string(static_cast<uint32_t*>(ptr)[idx]);
         case DataType::UInt64:
             return std::to_string(static_cast<uint64_t*>(ptr)[idx]);
         default:
             AIDGE_ASSERT(true, "unsupported type to convert to string");
         }
         return std::string("?"); // To make Clang happy
     };
 
-    if (dims().empty()) { return ptrToString(mDataType, mImpl->hostPtr(), 0); }
+    if (dims().empty()) {
+        return ptrToString(mDataType, mImpl->hostPtr(), 0);
+    }
 
     std::string res;
     std::size_t dim = 0;
     std::size_t counter = 0;
-    if (nbDims()>=2) {
+    if (nbDims() >= 2) {
         std::vector<std::size_t> dimVals(nbDims(), 0);
         res += "{\n";
         while (counter < mSize) {
-            std::string spaceString = std::string((dim+1)<<1,' ');
-            if (dim < nbDims()-2) {
+            std::string spaceString = std::string((dim + 1) << 1, ' ');
+            if (dim < nbDims() - 2) {
                 if (dimVals[dim] == 0) {
                     res += spaceString + "{\n";
                     ++dim;
-                } else if (dimVals[dim] < static_cast<std::size_t>(dims()[dim])) {
+                } else if (dimVals[dim] <
+                           static_cast<std::size_t>(dims()[dim])) {
                     res += spaceString + "},\n" + spaceString + "{\n";
                     ++dim;
                 } else {
@@ -161,13 +165,22 @@ std::string Aidge::Tensor::toString() const {
                     dimVals[dim]++;
                 }
             } else {
-                for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) {
+                for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]);
+                     ++dimVals[dim]) {
                     res += spaceString + "{";
                     for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) {
-                        res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + ",";
+                        res +=
+                            " " +
+                            ptrToString(mDataType, mImpl->hostPtr(mImplOffset),
+                                        counter++) +
+                            ",";
                     }
-                    res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + "}";
-                    if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) {
+                    res += " " +
+                           ptrToString(mDataType, mImpl->hostPtr(mImplOffset),
+                                       counter++) +
+                           "}";
+                    if (dimVals[dim] <
+                        static_cast<std::size_t>(dims()[dim] - 1)) {
                         res += ",";
                     }
                     res += "\n";
@@ -179,35 +192,45 @@ std::string Aidge::Tensor::toString() const {
                 dimVals[dim]++;
             }
         }
-        for(int i = static_cast<int>(dim); i >= 0; --i) {
-            res += std::string((i+1)<<1,' ') + "}\n";
+        if (nbDims() != 2) { // If nbDims == 2, parenthesis is already closed
+            for (int i = static_cast<int>(dim); i >= 0; --i) {
+                res += std::string((i + 1) << 1, ' ') + "}\n";
+            }
         }
     } else {
         res += "{";
         for (DimSize_t j = 0; j < dims()[0]; ++j) {
-            res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) + ((j < dims()[0]-1) ? "," : " ");
+            res += " " +
+                   ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) +
+                   ((j < dims()[0] - 1) ? "," : " ");
         }
     }
     res += "}";
     return res;
 }
 
-Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& fixedCoord) const {
+Aidge::Tensor Aidge::Tensor::extract(
+    const std::vector<std::size_t>& fixedCoord) const {
     AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
-    AIDGE_ASSERT(fixedCoord.size() <= mDims.size(), "Number of coordinates is higher than number of dimensions");
+    AIDGE_ASSERT(fixedCoord.size() <= mDims.size(),
+                 "Number of coordinates is higher than number of dimensions");
 
     Tensor subTensor(mDataType);
-    subTensor.resize(std::vector<size_t>(mDims.cbegin() + fixedCoord.size(), mDims.cend()),
-                     std::vector<size_t>(mStrides.cbegin() + fixedCoord.size(), mStrides.cend()));
+    subTensor.resize(
+        std::vector<size_t>(mDims.cbegin() + fixedCoord.size(), mDims.cend()),
+        std::vector<size_t>(mStrides.cbegin() + fixedCoord.size(),
+                            mStrides.cend()));
     subTensor.setBackend(mImpl->backend(), mImpl->device().second);
     subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(fixedCoord));
     return subTensor;
 }
 
-Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& startCoord, const std::vector<std::size_t>& dims) const {
+Aidge::Tensor Aidge::Tensor::extract(
+    const std::vector<std::size_t>& startCoord,
+    const std::vector<std::size_t>& dims) const {
     AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
-    AIDGE_ASSERT(startCoord.size() == mDims.size(), "Coordinates does not match number of dimensions");
+    AIDGE_ASSERT(startCoord.size() == mDims.size(),
+                 "Coordinates does not match number of dimensions");
 
     Tensor subTensor(mDataType);
     subTensor.resize(dims, mStrides);
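The extract() overloads above never copy data: they compute a storage offset and share the implementation via setImpl(). A standalone sketch of the offset arithmetic (the names here are illustrative, not Aidge's):

```cpp
#include <cstddef>
#include <vector>

// Flat storage index of a coordinate, given per-dimension strides. Fixing the
// leading coordinates of a contiguous tensor selects a sub-block that starts
// at this offset and reuses the remaining dims/strides unchanged, which is
// why a shared-storage view with an offset is sufficient.
std::size_t storageIdx(const std::vector<std::size_t>& coord,
                       const std::vector<std::size_t>& strides) {
    std::size_t idx = 0;
    for (std::size_t i = 0; i < coord.size(); ++i) {
        idx += coord[i] * strides[i];
    }
    return idx;
}

// Example: dims {2, 3, 4} with row-major strides {12, 4, 1}; fixing the
// leading coordinate {1} yields a {3, 4} view starting at storage index 12.
```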
@@ -224,7 +247,8 @@ void Aidge::Tensor::makeContiguous() {
     // Block so that mImpl ref count is 1 for resize()
     {
         // Create a new storage that will be contiguous
-        std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mDims);
+        std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create(
+            {mImpl->backend(), mDataType})(mImpl->device().second, mDims);
         // Copy elements from old to new storage
         std::size_t idx = 0;
         while (idx < mSize) {
@@ -233,13 +257,14 @@ void Aidge::Tensor::makeContiguous() {
             // Determine the size of the contiguous chunk
             std::size_t copySize = 1;
             while (idx + copySize < mSize &&
-                   getStorageIdx(getCoord(idx + copySize)) == storageIdx + copySize)
-            {
+                   getStorageIdx(getCoord(idx + copySize)) ==
+                       storageIdx + copySize) {
                 ++copySize;
             }
 
             // Perform a single copy for the contiguous chunk
-            newImpl->copy(mImpl->rawPtr(mImplOffset + storageIdx), copySize, idx);
+            newImpl->copy(mImpl->rawPtr(mImplOffset + storageIdx), copySize,
+                          idx);
 
             // Move to the next index after the contiguous chunk
             idx += copySize;
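The loop above avoids per-element copies by growing a chunk while logical order and storage order coincide, then copying the whole run at once. A standalone sketch of that chunking strategy, where the `mapToStorage` callback stands in for `getStorageIdx(getCoord(idx))` (assumes a trivially copyable element type, since it uses memcpy):

```cpp
#include <cstddef>
#include <cstring>

// Gather a strided source buffer into a contiguous destination, merging runs
// of consecutive logical indices that map to consecutive storage indices.
template <typename T, typename MapFn>
void gatherContiguous(const T* src, T* dst, std::size_t size,
                      MapFn mapToStorage) {
    std::size_t idx = 0;
    while (idx < size) {
        const std::size_t storageIdx = mapToStorage(idx);
        // Extend the chunk while the storage layout stays consecutive.
        std::size_t copySize = 1;
        while (idx + copySize < size &&
               mapToStorage(idx + copySize) == storageIdx + copySize) {
            ++copySize;
        }
        // One bulk copy for the whole run instead of copySize single copies.
        std::memcpy(dst + idx, src + storageIdx, copySize * sizeof(T));
        idx += copySize;
    }
}
```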
@@ -267,8 +292,10 @@ void Aidge::Tensor::copyCast(const Tensor& src) {
     }
 
     resize(src.dims());
-    AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device");
-    getImpl()->copyCast(src.getImpl()->rawPtr(src.mImplOffset), src.dataType(), src.size(), mImplOffset);
+    AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(),
+                 "cannot copy-cast from a different backend/device");
+    getImpl()->copyCast(src.getImpl()->rawPtr(src.mImplOffset), src.dataType(),
+                        src.size(), mImplOffset);
 }
 
 void Aidge::Tensor::copyFrom(const Tensor& src) {
@@ -286,16 +313,20 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
     }
 
     resize(src.dims());
-    AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type");
-    getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
+    AIDGE_ASSERT(src.dataType() == dataType(),
+                 "cannot copy from a different data type");
+    getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset,
+                        mImplOffset);
 }
 
-void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
+void Aidge::Tensor::copyCastFrom(const Tensor& src,
+                                 std::shared_ptr<Tensor>& movedSrcPtr) {
     if (&src == this) {
         return;
     }
 
-    AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast from non-contiguous tensor");
+    AIDGE_ASSERT(src.isContiguous(),
+                 "cannot copy-cast from non-contiguous tensor");
 
     // Current Tensor has necessarily a data type, but may not have backend
     if (!getImpl()) {
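To illustrate the division of labour between the two copies above: copyFrom() moves same-typed data (its assert rejects type mismatches), while copyCast() converts the element type and requires source and destination to share a backend/device. A hypothetical caller sketch, assuming a registered "cpu" backend; `castToFloat` is an illustrative helper, not an Aidge function:

```cpp
#include "aidge/data/Tensor.hpp"

// Produce a Float32 copy of src in dst. setDataType()/setBackend() prepare
// the destination storage; copyCast() then resizes dst to src's dims and
// converts the elements (asserting both tensors are on the same device).
void castToFloat(const Aidge::Tensor& src, Aidge::Tensor& dst) {
    dst.setDataType(Aidge::DataType::Float32);
    dst.setBackend("cpu");
    dst.copyCast(src);
}
```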
@@ -308,29 +339,33 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
     if (dataType() != src.dataType()) {
         // First move data to the target device (only if needed)
         const auto device = getImpl()->device();
-        const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second);
+        const Tensor& movedSrc =
+            src.refFrom(movedSrcPtr, device.first, device.second);
 
         // Second, copy-cast data (necessary)
-        getImpl()->copyCast(movedSrc.getImpl()->rawPtr(movedSrc.mImplOffset), movedSrc.dataType(), movedSrc.size(), mImplOffset);
-    }
-    else {
+        getImpl()->copyCast(movedSrc.getImpl()->rawPtr(movedSrc.mImplOffset),
+                            movedSrc.dataType(), movedSrc.size(), mImplOffset);
+    } else {
         // Directly copy, no conversion necessary
         // Avoid making a double copy if both data type and device are the same
-        getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
+        getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset,
+                            mImplOffset);
     }
 }
 
 Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) {
     // Scott Meyers' solution to avoid code duplication
-    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refContiguous(fallback));
+    return const_cast<Tensor&>(
+        static_cast<const Tensor&>(*this).refContiguous(fallback));
 }
 
-const Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) const {
-    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
+const Aidge::Tensor& Aidge::Tensor::refContiguous(
+    std::shared_ptr<Tensor>& fallback) const {
+    AIDGE_ASSERT(getImpl(),
+                 "no backend was set for tensor, cannot refCast() it");
     if (isContiguous()) {
         return *this;
-    }
-    else {
+    } else {
         if (this != fallback.get()) {
             // Shallow copy to fallback
             *fallback = *this;
@@ -342,96 +377,117 @@ const Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallb
     }
 }
 
-Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) {
+Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback,
+                                      const Aidge::DataType& dt) {
     // Scott Meyers' solution to avoid code duplication
-    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refCast(fallback, dt));
+    return const_cast<Tensor&>(
+        static_cast<const Tensor&>(*this).refCast(fallback, dt));
 }
 
-const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const {
-    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
+const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback,
+                                            const Aidge::DataType& dt) const {
+    AIDGE_ASSERT(getImpl(),
+                 "no backend was set for tensor, cannot refCast() it");
 
     if (dt == dataType()) {
         return *this;
-    }
-    else {
+    } else {
         if (this == fallback.get()) {
             // if refFrom() was called before, just change the type
             fallback->setDataType(dt);
-        }
-        else {
-            AIDGE_ASSERT(isContiguous(), "cannot refCast non-contiguous tensor");
+        } else {
+            AIDGE_ASSERT(isContiguous(),
+                         "cannot refCast non-contiguous tensor");
             if (!fallback) {
                 fallback = std::make_shared<Tensor>(dt);
-            }
-            else {
-                fallback->setDataType(dt, false); // don't keep previous data (no copy)
+            } else {
+                fallback->setDataType(
+                    dt, false); // don't keep previous data (no copy)
             }
 
             const auto device = getImpl()->device();
-            fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
+            fallback->setBackend(device.first, device.second,
+                                 false); // don't keep previous data (no copy)
             fallback->resize(dims());
-            fallback->getImpl()->copyCast(getImpl()->rawPtr(mImplOffset), dataType(), size(), fallback->mImplOffset);
+            fallback->getImpl()->copyCast(getImpl()->rawPtr(mImplOffset),
+                                          dataType(), size(),
+                                          fallback->mImplOffset);
         }
 
         return *fallback;
     }
 }
 
-Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device) {
+Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback,
+                                      const std::string& backend,
+                                      DeviceIdx_t device) {
    // Scott Meyers' solution to avoid code duplication
-    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refFrom(fallback, backend, device));
+    return const_cast<Tensor&>(
+        static_cast<const Tensor&>(*this).refFrom(fallback, backend, device));
 }
 
-const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device) const {
-    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refFrom() it");
+const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback,
+                                            const std::string& backend,
+                                            DeviceIdx_t device) const {
+    AIDGE_ASSERT(getImpl(),
+                 "no backend was set for tensor, cannot refFrom() it");
 
     if (std::make_pair(backend, device) == getImpl()->device()) {
         return *this;
-    }
-    else {
+    } else {
         if (this == fallback.get()) {
             // if refCast() was called before, just change the backend
             fallback->setBackend(backend, device);
-        }
-        else {
-            AIDGE_ASSERT(isContiguous(), "cannot refFrom non-contiguous tensor");
+        } else {
+            AIDGE_ASSERT(isContiguous(),
+                         "cannot refFrom non-contiguous tensor");
            if (!fallback) {
                fallback = std::make_shared<Tensor>(dataType());
-            }
-            else {
-                fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
+            } else {
+                fallback->setDataType(
+                    dataType(), false); // don't keep previous data (no copy)
            }
 
-            fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
+            fallback->setBackend(backend, device,
+                                 false); // don't keep previous data (no copy)
            fallback->resize(dims());
-            fallback->getImpl()->copyFrom(*getImpl(), size(), mImplOffset, fallback->mImplOffset);
+            fallback->getImpl()->copyFrom(*getImpl(), size(), mImplOffset,
+                                          fallback->mImplOffset);
        }
 
        return *fallback;
    }
 }
 
-Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) {
+Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback,
+                                  const Aidge::DataType& dt,
+                                  const std::string& backend,
+                                  DeviceIdx_t device) {
     // Scott Meyers' solution to avoid code duplication
-    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
+    return const_cast<Tensor&>(
+        static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
 }
 
-const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) const {
+const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback,
+                                        const Aidge::DataType& dt,
+                                        const std::string& backend,
+                                        DeviceIdx_t device) const {
     AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
 
-    if (dt == dataType() && std::make_pair(backend, device) == getImpl()->device()) {
+    if (dt == dataType() &&
+        std::make_pair(backend, device) == getImpl()->device()) {
         return *this;
-    }
-    else {
+    } else {
         // Change fallback type, backend & device, without any data copy
         if (!fallback) {
             fallback = std::make_shared<Tensor>(dt);
-        }
-        else {
-            fallback->setDataType(dt, false); // don't keep previous data (no copy)
+        } else {
+            fallback->setDataType(dt,
+                                  false); // don't keep previous data (no copy)
        }
 
-        fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
+        fallback->setBackend(backend, device,
+                             false); // don't keep previous data (no copy)
        fallback->resize(dims());
        return *fallback;
    }
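The `const_cast` one-liners above all follow the idiom the comments name, from Scott Meyers' Effective C++ (Item 3): implement the logic once in the const overload, then have the non-const overload delegate and cast the constness away, which is safe because the delegating overload is only selected on non-const objects. A generic sketch outside the Tensor class:

```cpp
#include <string>

class Widget {
public:
    // Real implementation lives in the const overload.
    const std::string& name() const { return mName; }

    // Non-const overload: add const to *this to select the const overload,
    // then strip const from the result. No duplicated logic.
    std::string& name() {
        return const_cast<std::string&>(
            static_cast<const Widget&>(*this).name());
    }

private:
    std::string mName{"widget"};
};
```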
@@ -439,7 +495,7 @@ const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const
 
 std::set<std::string> Aidge::Tensor::getAvailableBackends() {
     std::set<std::string> backendsList;
-    for(const auto& tupleKey : Registrar<Tensor>::getKeys())
+    for (const auto& tupleKey : Registrar<Tensor>::getKeys())
         backendsList.insert(std::get<0>(tupleKey));
     return backendsList;
 }