diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 6c51af6ab188b16a5eea83a6a144c7527075f2fc..3d40ca644403ff64e0285fbaafc6ea8c100eb704 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -723,8 +723,8 @@ class Tensor : public Data,
      * Return a reference to a Tensor on desired data type and backend/device:
      * - itself, if already with the right characteristics;
      * - the provided Tensor, overwritten with the copy-casted data.
-     * If required, fallback is always allocated on current (destination) 
-     * Tensor's device.
+     * If required, fallback is always allocated on desired (destination) 
+     * device.
      * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
      * The shared_ptr does not need to be initialized. No new memory allocation
      * will occur if fallback has already been allocated with the right 
@@ -735,7 +735,7 @@ class Tensor : public Data,
      * @return Reference to either itself or to fallback.
     */
     Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, int device = 0) {
-        // First refFrom, to ensure that fallback, if required, is on current Tensor's device
+        // First refFrom, to ensure that fallback, if required, is also on desired device
         return refFrom(fallback, backend, device).refCast(fallback, dt);
     }
 
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index b14fe994d03f8396cf10503743dfb89d2bf8ccec..15e6782a08a9d080688d95be9f16f0d66aec3fbe 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -52,17 +52,23 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
         return *this;
     }
     else {
-        if (!fallback) {
-            fallback = std::make_shared<Tensor>(dt);
+        if (this == fallback.get()) {
+            // if refFrom() was called before, just change the type
+            fallback->setDataType(dt);
         }
         else {
-            fallback->setDataType(dt, false); // don't keep previous data (no copy)
-        }
+            if (!fallback) {
+                fallback = std::make_shared<Tensor>(dt);
+            }
+            else {
+                fallback->setDataType(dt, false); // don't keep previous data (no copy)
+            }
 
-        const auto device = getImpl()->device();
-        fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
-        fallback->resize(dims());
-        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
+            const auto device = getImpl()->device();
+            fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
+            fallback->resize(dims());
+            fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
+        }
         return *fallback;
     }
 }
@@ -79,16 +85,22 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
         return *this;
     }
     else {
-        if (!fallback) {
-            fallback = std::make_shared<Tensor>(dataType());
+        if (this == fallback.get()) {
+            // if refCast() was called before, just change the backend
+            fallback->setBackend(backend, device);
         }
         else {
-            fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
-        }
+            if (!fallback) {
+                fallback = std::make_shared<Tensor>(dataType());
+            }
+            else {
+                fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
+            }
 
-        fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
-        fallback->resize(dims());
-        fallback->getImpl()->copyFrom(*getImpl(), size());
+            fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
+            fallback->resize(dims());
+            fallback->getImpl()->copyFrom(*getImpl(), size());
+        }
         return *fallback;
     }
 }