Skip to content
Snippets Groups Projects
Commit 42056687 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed issues, clarified conversion specs

parent abfa1d87
No related branches found
No related tags found
1 merge request!57Add Convert operator (a.k.a. Transmitter)
Pipeline #35562 failed
...@@ -485,16 +485,42 @@ class Tensor : public Data, ...@@ -485,16 +485,42 @@ class Tensor : public Data,
std::string toString() const { std::string toString() const {
// TODO: move lambda elsewhere?
auto ptrToString = [](DataType dt, void* ptr, size_t idx) {
switch (dt) {
case DataType::Float64:
return std::to_string(static_cast<double*>(ptr)[idx]);
case DataType::Float32:
return std::to_string(static_cast<float*>(ptr)[idx]);
case DataType::Float16:
return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
case DataType::Int8:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Int16:
return std::to_string(static_cast<int16_t*>(ptr)[idx]);
case DataType::Int32:
return std::to_string(static_cast<int32_t*>(ptr)[idx]);
case DataType::Int64:
return std::to_string(static_cast<int64_t*>(ptr)[idx]);
case DataType::UInt8:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
case DataType::UInt16:
return std::to_string(static_cast<uint16_t*>(ptr)[idx]);
case DataType::UInt32:
return std::to_string(static_cast<uint32_t*>(ptr)[idx]);
case DataType::UInt64:
return std::to_string(static_cast<uint64_t*>(ptr)[idx]);
default:
AIDGE_ASSERT(true, "unsupported type to convert to string");
}
};
if (dims().empty()) { return "{}"; } if (dims().empty()) { return "{}"; }
std::string res; std::string res;
std::size_t dim = 0; std::size_t dim = 0;
std::size_t counter = 0; std::size_t counter = 0;
if (nbDims()>=2) { if (nbDims()>=2) {
std::size_t *dimVals = new std::size_t[nbDims()]; std::vector<std::size_t> dimVals(nbDims(), 0);
for (std::size_t i = 0; i < nbDims(); ++i) {
dimVals[i] = 0;
}
// std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0);
res += "{\n"; res += "{\n";
while (counter < mSize) { while (counter < mSize) {
std::string spaceString = std::string((dim+1)<<1,' '); std::string spaceString = std::string((dim+1)<<1,' ');
...@@ -514,31 +540,9 @@ class Tensor : public Data, ...@@ -514,31 +540,9 @@ class Tensor : public Data,
for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) { for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) {
res += spaceString + "{"; res += spaceString + "{";
for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) {
switch (mDataType) res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + ",";
{
case DataType::Int32:
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + ",";
break;
case DataType::Float64:
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + ",";
break;
default:
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + ",";
break;
}
}
switch (mDataType)
{
case DataType::Int32:
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + "}";
break;
case DataType::Float64:
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + "}";
break;
default:
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + "}";
break;
} }
res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + "}";
if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) { if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) {
res += ","; res += ",";
} }
...@@ -551,7 +555,6 @@ class Tensor : public Data, ...@@ -551,7 +555,6 @@ class Tensor : public Data,
dimVals[dim]++; dimVals[dim]++;
} }
} }
delete[] dimVals;
for(int i = static_cast<int>(dim); i > 0; --i) { for(int i = static_cast<int>(dim); i > 0; --i) {
res += std::string((dim+1)<<1,' ') + "}\n"; res += std::string((dim+1)<<1,' ') + "}\n";
...@@ -559,18 +562,7 @@ class Tensor : public Data, ...@@ -559,18 +562,7 @@ class Tensor : public Data,
} else { } else {
res += "{"; res += "{";
for (DimSize_t j = 0; j < dims()[0]; ++j) { for (DimSize_t j = 0; j < dims()[0]; ++j) {
switch (mDataType) res += " " + ptrToString(mDataType, mImpl->rawPtr(), j) + ((j < dims()[0]-1) ? "," : "");
{
case DataType::Int32:
res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
case DataType::Float64:
res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
default:
res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "");
break;
}
} }
} }
res += "}"; res += "}";
...@@ -629,24 +621,30 @@ class Tensor : public Data, ...@@ -629,24 +621,30 @@ class Tensor : public Data,
/** /**
* Copy-cast data from a Tensor. * Copy-cast data from a Tensor.
* @param src Source tensor to copy-cast from. * @param src Source tensor to copy-cast from.
* @param convertedSrc shared_ptr to an indermediate Tensor that will * @param movedSrc shared_ptr to an indermediate Tensor that will
* contain the converted data if a conversion should occur. Any data already * contain the moved data if a device change should occur AND a type
* present will be overwritten. No new memory allocation will occur if * conversion is necessary (otherwise it remains unused).
* convertedSrc has already been allocated with the right type/size/device. * Any data already present will be overwritten. No new memory allocation
* will occur if movedSrc has already been allocated with the right
* type/size/device.
* If required, memory is always allocated on current (destination)
* Tensor's device.
*/ */
void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrc); void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc);
/** /**
* Copy-cast data from a Tensor. * Copy-cast data from a Tensor.
* In case of a conversion, an intermediate buffer will be allocated and * In case of both a device change AND a data type conversion, an
* deallocated each time. * intermediate buffer on will be allocated and deallocated each time.
* If required, buffer's memory is always allocated on current (destination)
* Tensor's device.
* @param src Source tensor to copy-cast from. * @param src Source tensor to copy-cast from.
*/ */
void copyCastFrom(const Tensor& src) { void copyCastFrom(const Tensor& src) {
// Internal buffers will be allocated and deallocated at each call // Internal buffer will be allocated and deallocated at each call
// (if they are needed) // (only if needed)
std::shared_ptr<Tensor> convertedSrc; std::shared_ptr<Tensor> movedSrc;
copyCastFrom(src, convertedSrc); copyCastFrom(src, movedSrc);
} }
/** /**
......
...@@ -67,9 +67,11 @@ public: ...@@ -67,9 +67,11 @@ public:
} }
private: private:
/// @brief Store the data to the right type on input device /// @brief Store the input data to the output device, before type conversion.
/// Required for any type conversion. /// Used only when there is both a change of device AND of data type.
std::shared_ptr<Tensor> mConvertedInput; /// Otherwise, data is either directly copied from the other device or
/// casted on the same device (requiring a single copy).
std::shared_ptr<Tensor> mMovedInput;
}; };
inline std::shared_ptr<Node> Convert(const std::string& name = "") { inline std::shared_ptr<Node> Convert(const std::string& name = "") {
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) { void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
if (srcImpl == *this) { if (&srcImpl == this) {
return; return;
} }
......
...@@ -13,13 +13,31 @@ ...@@ -13,13 +13,31 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrcPtr) { void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
if (src == *this) { if (&src == this) {
return; return;
} }
const Tensor& convertedSrc = src.refCast(convertedSrcPtr, dataType()); // Current Tensor has necessarily a data type, but may not have backend
getImpl()->copyFrom(*(convertedSrc.getImpl()), convertedSrc.size()); if (!getImpl()) {
// If no backend was set for the current tensor, use the same as src
const auto deviceSrc = src.getImpl()->device();
setBackend(deviceSrc.first, deviceSrc.second);
resize(src.dims());
}
if (dataType() != src.dataType()) {
// First move data to the target device (only if needed)
const auto device = getImpl()->device();
const Tensor& movedSrc = src.ref(movedSrcPtr, device.first, device.second);
// Second, copy-cast data (necessary)
getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType());
}
else {
// Directly copy, no conversion necessary
// Avoid making a double copy if both data type and device are the same
getImpl()->copyFrom(*(src.getImpl()), src.size());
}
} }
Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) { Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) {
...@@ -28,6 +46,8 @@ Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const A ...@@ -28,6 +46,8 @@ Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const A
} }
const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const { const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
if (dt == dataType()) { if (dt == dataType()) {
return *this; return *this;
} }
...@@ -53,6 +73,8 @@ Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std:: ...@@ -53,6 +73,8 @@ Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::
} }
const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const { const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
if (std::make_pair(backend, device) == getImpl()->device()) { if (std::make_pair(backend, device) == getImpl()->device()) {
return *this; return *this;
} }
......
...@@ -17,7 +17,7 @@ void Aidge::Convert_Op::forward() { ...@@ -17,7 +17,7 @@ void Aidge::Convert_Op::forward() {
mImpl->forward(); mImpl->forward();
} }
else { else {
mOutputs[0]->copyCastFrom(*(mInputs[0]), mConvertedInput); mOutputs[0]->copyCastFrom(*(mInputs[0]), mMovedInput);
} }
runHooks(); runHooks();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment