diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 7195f0b204c78a389541eb21409377a1eab3d24f..9a9f26bb2937bde36db3ec0bb9ac5b5feb0b542f 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -485,16 +485,42 @@ class Tensor : public Data, std::string toString() const { + // TODO: move lambda elsewhere? + auto ptrToString = [](DataType dt, void* ptr, size_t idx) { + switch (dt) { + case DataType::Float64: + return std::to_string(static_cast<double*>(ptr)[idx]); + case DataType::Float32: + return std::to_string(static_cast<float*>(ptr)[idx]); + case DataType::Float16: + return std::to_string(static_cast<half_float::half*>(ptr)[idx]); + case DataType::Int8: + return std::to_string(static_cast<int8_t*>(ptr)[idx]); + case DataType::Int16: + return std::to_string(static_cast<int16_t*>(ptr)[idx]); + case DataType::Int32: + return std::to_string(static_cast<int32_t*>(ptr)[idx]); + case DataType::Int64: + return std::to_string(static_cast<int64_t*>(ptr)[idx]); + case DataType::UInt8: + return std::to_string(static_cast<uint8_t*>(ptr)[idx]); + case DataType::UInt16: + return std::to_string(static_cast<uint16_t*>(ptr)[idx]); + case DataType::UInt32: + return std::to_string(static_cast<uint32_t*>(ptr)[idx]); + case DataType::UInt64: + return std::to_string(static_cast<uint64_t*>(ptr)[idx]); + default: + AIDGE_ASSERT(false, "unsupported type to convert to string"); + } + }; + if (dims().empty()) { return "{}"; } std::string res; std::size_t dim = 0; std::size_t counter = 0; if (nbDims()>=2) { - std::size_t *dimVals = new std::size_t[nbDims()]; - for (std::size_t i = 0; i < nbDims(); ++i) { - dimVals[i] = 0; - } - // std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0); + std::vector<std::size_t> dimVals(nbDims(), 0); res += "{\n"; while (counter < mSize) { std::string spaceString = std::string((dim+1)<<1,' '); @@ -514,31 +540,9 @@ class Tensor : public Data, for (; dimVals[dim] < 
static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) { res += spaceString + "{"; for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) { - switch (mDataType) - { - case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + ","; - break; - case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + ","; - break; - default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + ","; - break; - } - } - switch (mDataType) - { - case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[counter++]) + "}"; - break; - case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[counter++]) + "}"; - break; - default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[counter++]) + "}"; - break; + res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + ","; } + res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + "}"; if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) { res += ","; } @@ -551,7 +555,6 @@ class Tensor : public Data, dimVals[dim]++; } } - delete[] dimVals; for(int i = static_cast<int>(dim); i > 0; --i) { res += std::string((dim+1)<<1,' ') + "}\n"; @@ -559,18 +562,7 @@ class Tensor : public Data, } else { res += "{"; for (DimSize_t j = 0; j < dims()[0]; ++j) { - switch (mDataType) - { - case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); - break; - case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); - break; - default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); - break; - } + res += " " + ptrToString(mDataType, mImpl->rawPtr(), j) + ((j < dims()[0]-1) ? 
"," : ""); } } res += "}"; @@ -629,24 +621,30 @@ class Tensor : public Data, /** * Copy-cast data from a Tensor. * @param src Source tensor to copy-cast from. - * @param convertedSrc shared_ptr to an indermediate Tensor that will - * contain the converted data if a conversion should occur. Any data already - * present will be overwritten. No new memory allocation will occur if - * convertedSrc has already been allocated with the right type/size/device. + * @param movedSrc shared_ptr to an indermediate Tensor that will + * contain the moved data if a device change should occur AND a type + * conversion is necessary (otherwise it remains unused). + * Any data already present will be overwritten. No new memory allocation + * will occur if movedSrc has already been allocated with the right + * type/size/device. + * If required, memory is always allocated on current (destination) + * Tensor's device. */ - void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrc); + void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc); /** * Copy-cast data from a Tensor. - * In case of a conversion, an intermediate buffer will be allocated and - * deallocated each time. + * In case of both a device change AND a data type conversion, an + * intermediate buffer on will be allocated and deallocated each time. + * If required, buffer's memory is always allocated on current (destination) + * Tensor's device. * @param src Source tensor to copy-cast from. 
*/ void copyCastFrom(const Tensor& src) { - // Internal buffers will be allocated and deallocated at each call - // (if they are needed) - std::shared_ptr<Tensor> convertedSrc; - copyCastFrom(src, convertedSrc); + // Internal buffer will be allocated and deallocated at each call + // (only if needed) + std::shared_ptr<Tensor> movedSrc; + copyCastFrom(src, movedSrc); } /** diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Convert.hpp index cb54ffbf7487396c0fd9477265f9274f150aa39e..aa35d6812d95ce67af0a639e195846f4baf0b54a 100644 --- a/include/aidge/operator/Convert.hpp +++ b/include/aidge/operator/Convert.hpp @@ -67,9 +67,11 @@ public: } private: - /// @brief Store the data to the right type on input device - /// Required for any type conversion. - std::shared_ptr<Tensor> mConvertedInput; + /// @brief Store the input data to the output device, before type conversion. + /// Used only when there is both a change of device AND of data type. + /// Otherwise, data is either directly copied from the other device or + /// casted on the same device (requiring a single copy). 
+ std::shared_ptr<Tensor> mMovedInput; }; inline std::shared_ptr<Node> Convert(const std::string& name = "") { diff --git a/src/backend/TensorImpl.cpp b/src/backend/TensorImpl.cpp index 282f1222e8e944ab5afc2e5545e67be3b3b0f331..3982ee1fed9c9198b539bf9a28edd461992b791f 100644 --- a/src/backend/TensorImpl.cpp +++ b/src/backend/TensorImpl.cpp @@ -15,7 +15,7 @@ #include "aidge/utils/ErrorHandling.hpp" void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) { - if (srcImpl == *this) { + if (&srcImpl == this) { return; } diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 1f8257a706607d626ecec33a094352842445a4c0..a90b6b31d4479adf455f37f0741acb059c16abdf 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -13,13 +13,31 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" -void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrcPtr) { - if (src == *this) { +void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) { + if (&src == this) { return; } - const Tensor& convertedSrc = src.refCast(convertedSrcPtr, dataType()); - getImpl()->copyFrom(*(convertedSrc.getImpl()), convertedSrc.size()); + // Current Tensor has necessarily a data type, but may not have backend + if (!getImpl()) { + // If no backend was set for the current tensor, use the same as src + const auto deviceSrc = src.getImpl()->device(); + setBackend(deviceSrc.first, deviceSrc.second); + resize(src.dims()); + } + + if (dataType() != src.dataType()) { + // First move data to the target device (only if needed) + const auto device = getImpl()->device(); + const Tensor& movedSrc = src.ref(movedSrcPtr, device.first, device.second); + // Second, copy-cast data (necessary) + getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType()); + } + else { + // Directly copy, no conversion necessary + // Avoid making a double copy if both data type and device are the same + 
getImpl()->copyFrom(*(src.getImpl()), src.size()); + } } Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) { @@ -28,6 +46,8 @@ Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const A } const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const { + AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it"); + if (dt == dataType()) { return *this; } @@ -53,6 +73,8 @@ Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std:: } const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const { + AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it"); + if (std::make_pair(backend, device) == getImpl()->device()) { return *this; } diff --git a/src/operator/Convert.cpp b/src/operator/Convert.cpp index 14769add658ef598df2d47584042d99a430e3c1a..fcdcf4013f14f0673c0e9a53dbdb6d30543540f7 100644 --- a/src/operator/Convert.cpp +++ b/src/operator/Convert.cpp @@ -17,7 +17,7 @@ void Aidge::Convert_Op::forward() { mImpl->forward(); } else { - mOutputs[0]->copyCastFrom(*(mInputs[0]), mConvertedInput); + mOutputs[0]->copyCastFrom(*(mInputs[0]), mMovedInput); } runHooks();