diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 6c51af6ab188b16a5eea83a6a144c7527075f2fc..3d40ca644403ff64e0285fbaafc6ea8c100eb704 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -723,8 +723,8 @@ class Tensor : public Data,
      * Return a reference to a Tensor on desired data type and backend/device:
      * - itself, if already with the right characteristics;
      * - the provided Tensor, overwritten with the copy-casted data.
-     * If required, fallback is always allocated on current (destination)
-     * Tensor's device.
+     * If required, fallback is always allocated on desired (destination)
+     * device.
      * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
      * The shared_ptr does not need to be initialized. No new memory allocation
      * will occur if fallback has already been allocated with the right
@@ -735,7 +735,7 @@ class Tensor : public Data,
      * @return Reference to either itself or to fallback.
      */
     Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, int device = 0) {
-        // First refFrom, to ensure that fallback, if required, is on current Tensor's device
+        // First refFrom, to ensure that fallback, if required, is also on desired device
         return refFrom(fallback, backend, device).refCast(fallback, dt);
     }
 
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index b14fe994d03f8396cf10503743dfb89d2bf8ccec..15e6782a08a9d080688d95be9f16f0d66aec3fbe 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -52,17 +52,23 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
         return *this;
     }
     else {
-        if (!fallback) {
-            fallback = std::make_shared<Tensor>(dt);
+        if (this == fallback.get()) {
+            // if refFrom() was called before, just change the type
+            fallback->setDataType(dt);
         }
         else {
-            fallback->setDataType(dt, false); // don't keep previous data (no copy)
-        }
+            if (!fallback) {
+                fallback = std::make_shared<Tensor>(dt);
+            }
+            else {
+                fallback->setDataType(dt, false); // don't keep previous data (no copy)
+            }
 
-        const auto device = getImpl()->device();
-        fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
-        fallback->resize(dims());
-        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
+            const auto device = getImpl()->device();
+            fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
+            fallback->resize(dims());
+            fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
+        }
         return *fallback;
     }
 }
@@ -79,16 +79,22 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
         return *this;
     }
     else {
-        if (!fallback) {
-            fallback = std::make_shared<Tensor>(dataType());
+        if (this == fallback.get()) {
+            // if refCast() was called before, just change the backend
+            fallback->setBackend(backend, device);
         }
         else {
-            fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
-        }
+            if (!fallback) {
+                fallback = std::make_shared<Tensor>(dataType());
+            }
+            else {
+                fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
+            }
 
-        fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
-        fallback->resize(dims());
-        fallback->getImpl()->copyFrom(*getImpl(), size());
+            fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
+            fallback->resize(dims());
+            fallback->getImpl()->copyFrom(*getImpl(), size());
+        }
         return *fallback;
     }
 }