From 657c7302b334ed5ab9297bd6d213b79215313a36 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 26 Jan 2024 15:10:36 +0100 Subject: [PATCH] Added Tensor::ref() methods --- include/aidge/data/Tensor.hpp | 37 +++++++++++++++++++++++++++++++++++ src/data/Tensor.cpp | 26 ++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index eda3ee34b..b43406b5b 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -781,6 +781,43 @@ class Tensor : public Data, return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second); } + /** + * Return a reference to a Tensor on desired data type and backend/device: + * - itself, if already with the right characteristics; + * - the provided Tensor, overwritten with the right characteristics. + * NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on + * the same fallback, it remains valid, otherwise, data is invalid. + * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. + * The shared_ptr does not need to be initialized. No new memory allocation + * will occur if fallback has already been allocated with the right + * type/size/device. + * @param dt The desired data type. + * @param backend The desired backend. + * @param device The desired device. + * @return Reference to either itself or to fallback. + */ + Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0); + const Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) const; + + /** + * Return a reference to a Tensor with same characteristics + * (data type, backend/device) as targetReqs Tensor: + * - itself, if already with the right characteristics; + * - the provided Tensor, overwritten with the right characteristics. + * NOTE: no data is copy-casted. If it was so in a previous refCastFrom() on + * the same fallback, it remains valid, otherwise, data is invalid. + * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. + * The shared_ptr does not need to be initialized. No new memory allocation + * will occur if fallback has already been allocated with the right + * type/size/device. + * @param targetReqs Tensor with the desired target characteristics. + * @return Reference to either itself or to fallback. + */ + Tensor& ref(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) { + const auto& device = targetReqs.getImpl()->device(); + return ref(fallback, targetReqs.dataType(), device.first, device.second); + } + private: ///\bug not protected against overflow void computeSize() { diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp index 8099ee111..6ab096201 100644 --- a/src/data/Tensor.cpp +++ b/src/data/Tensor.cpp @@ -216,3 +216,29 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c return *fallback; } } + +Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) { + // Scott Meyers' solution to avoid code duplication + return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device)); +} + +const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) const { + AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it"); + + if (dt == dataType() && std::make_pair(backend, device) == getImpl()->device()) { + return *this; + } + else { + // Change fallback type, backend & device, without any data copy + if (!fallback) { + fallback = std::make_shared<Tensor>(dt); + } + else { + fallback->setDataType(dt, false); // don't keep previous data (no copy) + } + + fallback->setBackend(backend, device, false); // don't keep previous data (no copy) + fallback->resize(dims()); + return *fallback; + } +} -- GitLab