Skip to content
Snippets Groups Projects
Commit 443fc8d6 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Corrected wrong behavior when cascading refCast and refFrom

parent 3f16a3f7
No related branches found
No related tags found
No related merge requests found
......@@ -723,8 +723,8 @@ class Tensor : public Data,
* Return a reference to a Tensor on desired data type and backend/device:
* - itself, if already with the right characteristics;
* - the provided Tensor, overwritten with the copy-casted data.
* If required, fallback is always allocated on desired (destination)
* device.
* @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
* The shared_ptr does not need to be initialized. No new memory allocation
* will occur if fallback has already been allocated with the right
......@@ -735,7 +735,7 @@ class Tensor : public Data,
* @return Reference to either itself or to fallback.
*/
Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, int device = 0) {
    // Relocate first (refFrom), then cast in place (refCast). Doing the
    // backend/device move first guarantees that a fallback, if one has to be
    // allocated, lives on the desired (destination) device.
    auto& onTarget = refFrom(fallback, backend, device);
    return onTarget.refCast(fallback, dt);
}
......
// NOTE(review): the span below is a scraped GitLab diff hunk for
// Aidge::Tensor::refCast — the '+'/'-' markers and indentation were lost, so
// pre-commit (removed) and post-commit (added) lines are interleaved as
// duplicates. This is diff residue, not compilable source; confirm the final
// code against the repository at commit 443fc8d6.
......@@ -52,17 +52,23 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
return *this;
}
else {
// NOTE(review): the next two lines also appear at L45-46 below — presumably
// the pre-commit placement of the fallback allocation; TODO confirm.
if (!fallback) {
fallback = std::make_shared<Tensor>(dt);
// The commit's addition: when this tensor IS the fallback (i.e. refFrom()
// already redirected to it), only the data type needs changing — no new
// allocation or copy-cast.
if (this == fallback.get()) {
// if refFrom() was called before, just change the type
fallback->setDataType(dt);
}
else {
fallback->setDataType(dt, false); // don't keep previous data (no copy)
}
if (!fallback) {
fallback = std::make_shared<Tensor>(dt);
}
else {
fallback->setDataType(dt, false); // don't keep previous data (no copy)
}
// NOTE(review): the following four statements appear twice (here and at
// L55-58) — same code at pre- and post-commit indentation; the duplication
// is a diff artifact, not a double execution.
const auto device = getImpl()->device();
fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
fallback->resize(dims());
fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
const auto device = getImpl()->device();
fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
fallback->resize(dims());
fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
}
return *fallback;
}
}
// NOTE(review): as with refCast above, this span is a scraped GitLab diff hunk
// for Aidge::Tensor::refFrom — removed and added lines are interleaved without
// markers and indentation is lost. Diff residue, not compilable source;
// confirm the final code against the repository at commit 443fc8d6.
......@@ -79,16 +85,22 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
return *this;
}
else {
// NOTE(review): the next two lines also appear at L76-77 below — presumably
// the pre-commit placement of the fallback allocation; TODO confirm.
if (!fallback) {
fallback = std::make_shared<Tensor>(dataType());
// The commit's addition: when this tensor IS the fallback (i.e. refCast()
// already redirected to it), only the backend needs changing — no fresh
// allocation of the fallback.
if (this == fallback.get()) {
// if refCast() was called before, just change the backend
fallback->setBackend(backend, device);
}
else {
fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
}
if (!fallback) {
fallback = std::make_shared<Tensor>(dataType());
}
else {
fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
}
// NOTE(review): the following three statements appear twice (here and at
// L85-87) — same code at pre- and post-commit indentation; the duplication
// is a diff artifact, not a double execution.
fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
fallback->resize(dims());
fallback->getImpl()->copyFrom(*getImpl(), size());
fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
fallback->resize(dims());
fallback->getImpl()->copyFrom(*getImpl(), size());
}
return *fallback;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment