From 657c7302b334ed5ab9297bd6d213b79215313a36 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 26 Jan 2024 15:10:36 +0100
Subject: [PATCH] Added Tensor::ref() methods

---
 include/aidge/data/Tensor.hpp | 37 +++++++++++++++++++++++++++++++++++
 src/data/Tensor.cpp           | 26 ++++++++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index eda3ee34b..b43406b5b 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -781,6 +781,43 @@ class Tensor : public Data,
         return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
     }
 
+    /**
+     * Return a reference to a Tensor on desired data type and backend/device:
+     * - itself, if already with the right characteristics;
+     * - the provided Tensor, overwritten with the right characteristics.
+     * NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on
+     * the same fallback, it remains valid, otherwise, data is invalid.
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+     * The shared_ptr does not need to be initialized. No new memory allocation
+     * will occur if fallback has already been allocated with the right
+     * type/size/device.
+     * @param dt The desired data type.
+     * @param backend The desired backend.
+     * @param device The desired device.
+     * @return Reference to either itself or to fallback.
+    */
+    Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0);
+    const Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) const;
+
+    /**
+     * Return a reference to a Tensor with same characteristics
+     * (data type, backend/device) as targetReqs Tensor:
+     * - itself, if already with the right characteristics;
+     * - the provided Tensor, overwritten with the right characteristics.
+     * NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on
+     * the same fallback, it remains valid, otherwise, data is invalid.
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+     * The shared_ptr does not need to be initialized. No new memory allocation
+     * will occur if fallback has already been allocated with the right
+     * type/size/device.
+     * @param targetReqs Tensor with the desired target characteristics.
+     * @return Reference to either itself or to fallback.
+    */
+    Tensor& ref(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) {
+        const auto& device = targetReqs.getImpl()->device();
+        return ref(fallback, targetReqs.dataType(), device.first, device.second);
+    }
+
 private:
     ///\bug not protected against overflow
     void computeSize() {
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index 8099ee111..6ab096201 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -216,3 +216,29 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
         return *fallback;
     }
 }
+
+Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) {
+    // Scott Meyers' solution to avoid code duplication
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
+}
+
+const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) const {
+    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
+
+    if (dt == dataType() && std::make_pair(backend, device) == getImpl()->device()) {
+        return *this;
+    }
+    else {
+        // Change fallback type, backend & device, without any data copy
+        if (!fallback) {
+            fallback = std::make_shared<Tensor>(dt);
+        }
+        else {
+            fallback->setDataType(dt, false); // don't keep previous data (no copy)
+        }
+
+        fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
+        fallback->resize(dims());
+        return *fallback;
+    }
+}
-- 
GitLab