Commit 42056687 authored by Olivier BICHLER

Fixed issues, clarified conversion specs

parent abfa1d87
@@ -485,16 +485,42 @@ class Tensor : public Data,
std::string toString() const {
// TODO: move lambda elsewhere?
auto ptrToString = [](DataType dt, void* ptr, size_t idx) {
switch (dt) {
case DataType::Float64:
return std::to_string(static_cast<double*>(ptr)[idx]);
case DataType::Float32:
return std::to_string(static_cast<float*>(ptr)[idx]);
case DataType::Float16:
return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
case DataType::Int8:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Int16:
return std::to_string(static_cast<int16_t*>(ptr)[idx]);
case DataType::Int32:
return std::to_string(static_cast<int32_t*>(ptr)[idx]);
case DataType::Int64:
return std::to_string(static_cast<int64_t*>(ptr)[idx]);
case DataType::UInt8:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
case DataType::UInt16:
return std::to_string(static_cast<uint16_t*>(ptr)[idx]);
case DataType::UInt32:
return std::to_string(static_cast<uint32_t*>(ptr)[idx]);
case DataType::UInt64:
return std::to_string(static_cast<uint64_t*>(ptr)[idx]);
default:
AIDGE_ASSERT(false, "unsupported type to convert to string");
}
};
if (dims().empty()) { return "{}"; }
std::string res;
std::size_t dim = 0;
std::size_t counter = 0;
if (nbDims()>=2) {
std::size_t *dimVals = new std::size_t[nbDims()];
for (std::size_t i = 0; i < nbDims(); ++i) {
dimVals[i] = 0;
}
// std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0);
std::vector<std::size_t> dimVals(nbDims(), 0);
res += "{\n";
while (counter < mSize) {
std::string spaceString = std::string((dim+1)<<1,' ');
@@ -514,31 +540,9 @@ class Tensor : public Data,
for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) {
res += spaceString + "{";
for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) {
switch (mDataType)
{
case DataType::Int32:
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + ",";
break;
case DataType::Float64:
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + ",";
break;
default:
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + ",";
break;
}
res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + ",";
}
switch (mDataType)
{
case DataType::Int32:
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + "}";
break;
case DataType::Float64:
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + "}";
break;
default:
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + "}";
break;
}
res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + "}";
if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) {
res += ",";
}
@@ -551,7 +555,6 @@ class Tensor : public Data,
dimVals[dim]++;
}
}
delete[] dimVals;
for(int i = static_cast<int>(dim); i > 0; --i) {
res += std::string((dim+1)<<1,' ') + "}\n";
@@ -559,18 +562,7 @@ class Tensor : public Data,
} else {
res += "{";
for (DimSize_t j = 0; j < dims()[0]; ++j) {
switch (mDataType)
{
case DataType::Int32:
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
case DataType::Float64:
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
default:
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
}
res += " " + ptrToString(mDataType, mImpl->rawPtr(), j) + ((j < dims()[0]-1) ? "," : "");
}
}
res += "}";
@@ -629,24 +621,30 @@ class Tensor : public Data,
/**
* Copy-cast data from a Tensor.
* @param src Source tensor to copy-cast from.
* @param convertedSrc shared_ptr to an intermediate Tensor that will
* contain the converted data if a conversion should occur. Any data already
* present will be overwritten. No new memory allocation will occur if
* convertedSrc has already been allocated with the right type/size/device.
* @param movedSrc shared_ptr to an intermediate Tensor that will
* contain the moved data if a device change should occur AND a type
* conversion is necessary (otherwise it remains unused).
* Any data already present will be overwritten. No new memory allocation
* will occur if movedSrc has already been allocated with the right
* type/size/device.
* If required, memory is always allocated on current (destination)
* Tensor's device.
*/
void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrc);
void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc);
/**
* Copy-cast data from a Tensor.
* In case of a conversion, an intermediate buffer will be allocated and
* deallocated each time.
* In case of both a device change AND a data type conversion, an
* intermediate buffer will be allocated and deallocated each time.
* If required, buffer's memory is always allocated on current (destination)
* Tensor's device.
* @param src Source tensor to copy-cast from.
*/
void copyCastFrom(const Tensor& src) {
// Internal buffers will be allocated and deallocated at each call
// (if they are needed)
std::shared_ptr<Tensor> convertedSrc;
copyCastFrom(src, convertedSrc);
// Internal buffer will be allocated and deallocated at each call
// (only if needed)
std::shared_ptr<Tensor> movedSrc;
copyCastFrom(src, movedSrc);
}
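For context, a minimal usage sketch of the two overloads above. This is a hedged illustration, not part of the commit: the default construction, the setDataType() setter and the "cpu" backend name are assumptions; setBackend(), resize() and dims() are used as they appear elsewhere in this diff.

// Sketch: reuse a persistent fallback so the intermediate Tensor is
// allocated once instead of on every call (assumed setup).
Tensor src;                             // e.g. Float16 data on some device
Tensor dst;
dst.setDataType(DataType::Float32);     // destination type drives the cast (assumed setter)
dst.setBackend("cpu");                  // destination device drives the move
dst.resize(src.dims());

std::shared_ptr<Tensor> movedSrc;       // persistent intermediate buffer
for (int step = 0; step < 100; ++step) {
    dst.copyCastFrom(src, movedSrc);    // movedSrc is reused after the first call
}

dst.copyCastFrom(src);                  // convenience overload: any intermediate
                                        // buffer is allocated and freed per call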
/**
@@ -67,9 +67,11 @@ public:
}
private:
/// @brief Store the data to the right type on input device
/// Required for any type conversion.
std::shared_ptr<Tensor> mConvertedInput;
/// @brief Store the input data on the output device, before type conversion.
/// Used only when there is both a change of device AND of data type.
/// Otherwise, data is either copied directly from the other device or
/// cast on the same device (requiring a single copy).
std::shared_ptr<Tensor> mMovedInput;
};
inline std::shared_ptr<Node> Convert(const std::string& name = "") {
@@ -15,7 +15,7 @@
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
if (srcImpl == *this) {
if (&srcImpl == this) {
return;
}
@@ -13,13 +13,31 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrcPtr) {
if (src == *this) {
void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
if (&src == this) {
return;
}
const Tensor& convertedSrc = src.refCast(convertedSrcPtr, dataType());
getImpl()->copyFrom(*(convertedSrc.getImpl()), convertedSrc.size());
// The current Tensor necessarily has a data type, but may not have a backend
if (!getImpl()) {
// If no backend was set for the current tensor, use the same as src
const auto deviceSrc = src.getImpl()->device();
setBackend(deviceSrc.first, deviceSrc.second);
resize(src.dims());
}
if (dataType() != src.dataType()) {
// First move data to the target device (only if needed)
const auto device = getImpl()->device();
const Tensor& movedSrc = src.ref(movedSrcPtr, device.first, device.second);
// Second, copy-cast the data (always required when types differ)
getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType());
}
else {
// Same data type: copy directly, no conversion necessary.
// A direct copy also handles a device change, so this avoids a double copy.
getImpl()->copyFrom(*(src.getImpl()), src.size());
}
}
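To recap the three paths of the new implementation, a hedged sketch (the tensors, types and the "cpu"/"cuda" backend names are assumptions for illustration):

Tensor src;                              // assume: Float16 data on "cuda"
Tensor dst;                              // assume: DataType::Float32, no backend yet

// Path 1: dst has no implementation -> it adopts src's backend/device and
// dims, then proceeds with one of the two paths below.
dst.copyCastFrom(src);

// Path 2: dst is on "cpu" and data types differ -> src is first moved to
// "cpu" into the fallback buffer, then copy-cast into dst.
dst.setBackend("cpu");
std::shared_ptr<Tensor> movedSrc;
dst.copyCastFrom(src, movedSrc);

// Path 3: same data type on both sides -> a single direct copyFrom();
// movedSrc stays unused even if the devices differ.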
Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) {
@@ -28,6 +46,8 @@ Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const A
}
const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
if (dt == dataType()) {
return *this;
}
@@ -53,6 +73,8 @@ Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::
}
const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
if (std::make_pair(backend, device) == getImpl()->device()) {
return *this;
}
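For reference, the fallback pattern these two functions implement, as a hedged usage sketch (the setup of src is assumed; the signatures are those shown above):

// refCast() returns *this when the type already matches; otherwise it
// materializes a cast copy into the caller-provided fallback and returns it.
std::shared_ptr<Tensor> fallback;
const Tensor& srcF32 = src.refCast(fallback, DataType::Float32);
// If src was already Float32: &srcF32 == &src and fallback stays null.

// ref() applies the same pattern to the (backend, device) pair:
std::shared_ptr<Tensor> fallback2;
const Tensor& srcCpu = src.ref(fallback2, "cpu", 0);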
@@ -17,7 +17,7 @@ void Aidge::Convert_Op::forward() {
mImpl->forward();
}
else {
mOutputs[0]->copyCastFrom(*(mInputs[0]), mConvertedInput);
mOutputs[0]->copyCastFrom(*(mInputs[0]), mMovedInput);
}
runHooks();