Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
1832 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Tensor.hpp 28.48 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_DATA_TENSOR_H_
#define AIDGE_CORE_DATA_TENSOR_H_

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>      // printf
#include <cstring>
#include <functional>  // std::multiplies
#include <memory>
#include <numeric>     // std::accumulate, std::inner_product
#include <set>
#include <string>
#include <tuple>
#include <utility>     // std::make_pair
#include <vector>

#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ArrayHelpers.hpp"

namespace Aidge {
/**
 * @brief Description for the tensor data structure.
 * @details Holds the properties of the tensor (data type, dimensions,
 * strides) together with a shared pointer to the implementation that stores
 * the data. The Tensor's elements start at an offset inside that storage and
 * are addressed through the strides, so several Tensors may share the same
 * storage (e.g. the sub-tensors returned by extract()).
 */
class Tensor : public Data,
               public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)> {
   private:
    DataType mDataType; /**< enum to specify data type. */
    std::vector<DimSize_t> mDims; /**< Dimensions of the tensor. */
    std::vector<DimSize_t> mStrides; /**< Stride dimensions of the tensor. */
    std::shared_ptr<TensorImpl> mImpl; /**< Pointer to the actual data implementation. */
    std::size_t mImplOffset = 0; /**< Offset of this Tensor's first element in mImpl. */
    std::shared_ptr<Tensor> mGrad; /**< Pointer to the associated gradient Tensor instance. */

    // Cached data
    std::size_t mSize = 0;    /**< Number of elements in the Tensor. */
    bool mContiguous = true;  /**< Whether the elements are contiguous in storage. */

   public:
    static constexpr const char *Type = "Tensor";

    /**
     * @brief Construct a new empty Tensor object.
     * @param dataType Sets the type of inserted data.
     */
    Tensor(DataType dataType = DataType::Float32)
        : Data(Type),
          mDataType(dataType)
    {
        // ctor
    }

    /**
     * @brief Construct a new Tensor object from another one (shallow copy).
     * Data memory is not copied, but shared between the new Tensor and the
     * initial one.
     *
     * @param otherTensor
     */
    Tensor(const Tensor&)            = default;
    Tensor(Tensor&&)            = default;

    /**
     * @brief Perform a deep copy of the tensor.
     * @details Dimensions, strides and data type are copied. If the Tensor has
     * an implementation, its data is copied into newly allocated storage on
     * the same backend/device. A Tensor without implementation is returned as
     * a plain copy, since there is no data to duplicate.
     * @note The gradient pointer is copied as-is (i.e. shared) by the copy
     * constructor.
     * @return Tensor Copy of this Tensor owning its own data.
     */
    Tensor clone() const {
        Tensor newTensor(*this);
        if (!mImpl) {
            // Nothing to duplicate; also avoids dereferencing a null impl below.
            return newTensor;
        }
        if (!newTensor.isContiguous()) {
            // makeContiguous() allocates a new memory space for the copy.
            newTensor.makeContiguous();
        }
        else {
            std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mSize);
            newImpl->copy(mImpl->rawPtr(mImplOffset), mSize);
            newTensor.setImpl(newImpl);
        }
        return newTensor;
    }

    /**
     * @brief Construct a new Tensor object from the 1-dimension Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     */
    template <typename T, std::size_t SIZE_0>
    constexpr Tensor(Array1D<T, SIZE_0> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0}),
          mStrides({1}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0)),
          mSize(SIZE_0) {
        mImpl->copyFromHost(&arr.data[0], SIZE_0);
    }

    /**
     * @brief Overwrite dimensions, data type and data from a 1-dimension Array.
     * @details The Tensor's data type becomes the Array's element type. If an
     * implementation already exists, its backend/device is kept; otherwise a
     * "cpu" implementation on device 0 is created.
     */
    template <typename T, std::size_t SIZE_0>
    constexpr Tensor &operator=(Array1D<T, SIZE_0> &&arr) {
        resize({SIZE_0});
        // Keep mDataType (and any existing impl) consistent with the element
        // type being written. Previous data is overwritten, so no copy-cast.
        setDataType(NativeType<T>::type, false);
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0);
        }
        mImpl->copyFromHost(&arr.data[0], SIZE_0, mImplOffset);
        return *this;
    }

    /**
     * @brief Construct a new Tensor object from the 2-dimensions Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     * @tparam SIZE_1 second array dimension.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1>
    constexpr Tensor(Array2D<T, SIZE_0, SIZE_1> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1}),
          mStrides({SIZE_1, 1}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1)),
          mSize(SIZE_0 * SIZE_1) {
        mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
    }

    /**
     * @brief Overwrite dimensions, data type and data from a 2-dimensions Array.
     * @details See the 1-dimension overload.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1>
    constexpr Tensor &operator=(Array2D<T, SIZE_0, SIZE_1> &&arr) {
        resize({SIZE_0, SIZE_1});
        // Keep mDataType consistent with the element type being written.
        setDataType(NativeType<T>::type, false);
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1);
        }
        mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1, mImplOffset);
        return *this;
    }

    /**
     * @brief Construct a new Tensor object from the 3-dimensions Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     * @tparam SIZE_1 second array dimension.
     * @tparam SIZE_2 third array dimension.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2>
    constexpr Tensor(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1, SIZE_2}),
          mStrides({SIZE_1 * SIZE_2, SIZE_2, 1}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2)),
          mSize(SIZE_0 * SIZE_1 * SIZE_2) {
        mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
    }

    /**
     * @brief Overwrite dimensions, data type and data from a 3-dimensions Array.
     * @details See the 1-dimension overload.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2>
    constexpr Tensor &operator=(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr) {
        resize({SIZE_0, SIZE_1, SIZE_2});
        // Keep mDataType consistent with the element type being written.
        setDataType(NativeType<T>::type, false);
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2);
        }
        mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2, mImplOffset);
        return *this;
    }

    /**
     * @brief Construct a new Tensor object from the 4-dimensions Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     * @tparam SIZE_1 second array dimension.
     * @tparam SIZE_2 third array dimension.
     * @tparam SIZE_3 fourth array dimension.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3>
    constexpr Tensor(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}),
          mStrides({SIZE_1 * SIZE_2 * SIZE_3, SIZE_2 * SIZE_3, SIZE_3, 1}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3)),
          mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3) {
        mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
    }

    /**
     * @brief Overwrite dimensions, data type and data from a 4-dimensions Array.
     * @details See the 1-dimension overload.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3>
    constexpr Tensor &operator=(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr) {
        resize({SIZE_0, SIZE_1, SIZE_2, SIZE_3});
        // Keep mDataType consistent with the element type being written.
        setDataType(NativeType<T>::type, false);
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(0, SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
        }
        mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3, mImplOffset);
        return *this;
    }

    /**
     * @brief Copy dimensions, datatype and data from another Tensor.
     * If current Tensor already has an implementation, data is copied to the
     * existing implementation. Tensor backend/device remain untouched.
     * If current Tensor does not have an implementation, only a shallow copy
     * is performed and the Tensor will share data with t.
     * Self-assignment is a no-op.
     * @param t other Tensor object.
     * @return Tensor&
     */
    Tensor &operator=(const Tensor &t) {
        if (this == &t) {
            // resize()/copyFrom() would otherwise read from the object being written.
            return *this;
        }
        resize(t.dims(), t.strides());
        setDataType(t.dataType(), false); // do not convert existing data
        if (t.hasImpl()) {
            if (hasImpl()) {
                copyFrom(t);
            }
            else {
                // Perform a shallow copy only
                setImpl(t.mImpl, t.mImplOffset);
            }
        }
        else {
            setImpl(nullptr);
        }
        return *this;
    }

    /**
     * @brief Assess data type, dimensions, backend and data are the same.
     * @details A Tensor without implementation is never equal to any Tensor
     * (including another Tensor without implementation).
     * @note NOTE(review): the implementations are compared as a whole;
     * mImplOffset is not taken into account — confirm this is intended for
     * Tensors sharing storage (e.g. from extract()).
     * @param otherTensor
     */
    bool operator==(const Tensor &otherTensor) const {
        // Checking both pointers first prevents dereferencing a null impl
        // when only one of the two Tensors has an implementation.
        if (!mImpl || !otherTensor.mImpl) {
            return false;
        }
        if ((dataType() != otherTensor.dataType()) ||
            (dims() != otherTensor.dims()) ||
            (mImpl->backend() != otherTensor.mImpl->backend())) {
            return false;
        }
        return *mImpl == *(otherTensor.mImpl);
    }

    /**
     * @brief Set the backend of the Tensor associated implementation. If there
     * was no previous implementation set, data will be allocated, but it will
     * not be initialized to any particular value.
     * If data was already initialized in a previous backend, it will be moved
     * to the new one except if copyFrom is false.
     * @param name Backend name
     * @param device Backend device
     * @param copyFrom If true (default), move data from previous backend/device
     * to the new one. Previous data is lost otherwise.
     */
    inline void setBackend(const std::string &name, DeviceIdx_t device = 0, bool copyFrom = true) {
        if (mImpl) {
            if (mImpl->device() != std::make_pair(name, device)) {
                // Backend change: create new impl, copy from old to new and replace
                // impl
                std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(device, mImpl->size());
                if (copyFrom) {
                    // NOTE(review): with a non-zero mImplOffset this copies
                    // mImpl->size() elements starting at that offset — confirm
                    // TensorImpl::copyFrom bounds the read accordingly.
                    newImpl->copyFrom(*mImpl, mImpl->size(), mImplOffset, 0);
                }
                setImpl(newImpl);
            }
        }
        else {
            // setImpl() also resets the offset: fresh storage starts at 0.
            setImpl(Registrar<Tensor>::create({name, mDataType})(device, mSize));
        }
    }

    /**
     * @brief Get a list of available backends.
     * @return std::set<std::string>
     */
    static std::set<std::string> getAvailableBackends(){
        std::set<std::string> backendsList;
        for (const auto& tupleKey : Registrar<Tensor>::getKeys()) {
            backendsList.insert(std::get<0>(tupleKey));
        }
        return backendsList;
    }

    /**
     * @brief Get the data type enum.
     * @return constexpr DataType
     */
    constexpr DataType dataType() const { return mDataType; }

    /**
     * @brief Set the DataType of the Tensor and converts data
     * if the Tensor has already been initialized and copyCast is true.
     * @details When an implementation exists and the type changes, a new
     * implementation is created on the same backend/device (its offset is 0).
     * @param dt DataType
     * @param copyCast If true (default), previous data is copy-casted. Otherwise
     * previous data is lost.
     */
    void setDataType(const DataType dt, bool copyCast = true) {
        if (mImpl && (dataType() != dt)) {
            std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(mImpl->device().second, mImpl->size());
            if (copyCast) {
                newImpl->copyCast(mImpl->rawPtr(mImplOffset), mDataType, mImpl->size());
            }
            setImpl(newImpl);
        }
        mDataType = dt;
    }

    /**
     * @brief Get the Impl object
     * @return constexpr const std::shared_ptr<TensorImpl>&
     */
    constexpr const std::shared_ptr<TensorImpl> &getImpl() const { return mImpl; }

    /**
     * @brief Get the storage offset of this Tensor inside its implementation.
     * @return constexpr std::size_t
     */
    constexpr std::size_t getImplOffset() const { return mImplOffset; }

    /**
     * @brief Set the Impl object
     *
     * @param impl New impl shared pointer
     * @param implOffset Storage offset in this new impl for this Tensor
     */
    void setImpl(std::shared_ptr<TensorImpl> impl, std::size_t implOffset = 0) {
        mImpl = impl;
        mImplOffset = implOffset;
    }

    /**
     * @brief Return if an implementation has been associated.
     * @return true
     * @return false
     */
    bool hasImpl() const { return mImpl != nullptr; }

    /**
     * @brief Get number of dimensions of the Tensor.
     * @return std::size_t
     */
    inline std::size_t nbDims() const { return mDims.size(); }

    /**
     * @brief Get dimensions of the Tensor object.
     * @tparam DIM number of dimensions.
     * @return constexpr std::array<DimSize_t, DIM>
     */
    template <DimIdx_t DIM>
    constexpr std::array<DimSize_t, DIM> dims() const {
        assert(DIM == mDims.size() && "wrong number of dimensions");
        return to_array<DIM>(mDims.cbegin());
    }

    /**
     * @brief Get dimensions of the Tensor object.
     * @return constexpr const std::vector<DimSize_t>&
     */
    constexpr const std::vector<DimSize_t> &dims() const { return mDims; }

    /**
     * @brief Get strides of the Tensor object.
     * @return constexpr const std::vector<DimSize_t>&
     */
    constexpr const std::vector<DimSize_t> &strides() const { return mStrides; }

    /**
     * @brief Return true if Tensor is contiguous in memory.
     * @return bool
     */
    constexpr bool isContiguous() const { return mContiguous; }

    /**
     * @brief Get the number of elements in the Tensor object.
     * @return constexpr std::size_t
     */
    constexpr std::size_t size() const { return mSize; }

    /**
     * @brief Change the dimensions of the Tensor object according to the given argument.
     * If the overall size is not changed (meaning we actually only performed a
     * reshape), data is guaranteed to remain valid.
     * Otherwise, no guarantee is provided regarding the validity of previous data
     * (unlike std::vector). If the new overall size is larger than the previous
     * one, all previous data is invalidated. Otherwise, previous data may or may
     * not remain valid, depending on the backend implementation.
     * @tparam DIM Number of dimensions.
     * @param dims New dimensions
     */
    template <std::array<DimSize_t, 1>::size_type DIM> // deducing std::array size_type and declaring DIM accordingly
    inline void resize(const std::array<DimSize_t, DIM> &dims) {
        resize(std::vector<DimSize_t>(dims.begin(), dims.end()));
    }

    /**
     * @brief Change the dimensions of the Tensor object according to the given argument.
     * If the overall size is not changed (meaning we actually only performed a
     * reshape), data is guaranteed to remain valid.
     * Otherwise, no guarantee is provided regarding the validity of previous data
     * (unlike std::vector). If the new overall size is larger than the previous
     * one, all previous data is invalidated. Otherwise, previous data may or may
     * not remain valid, depending on the backend implementation.
     * @param dims New dimensions
     * @param strides Stride of the tensor (if not specified, "nested" stride is used)
     */
    void resize(const std::vector<DimSize_t> &dims, std::vector<DimSize_t> strides = std::vector<DimSize_t>());

    /**
     * @brief Return whether the Tensor has no dimensions.
     * @note A Tensor without dimensions is counted as one element by the
     * size computation (scalar), so empty() does not mean size() == 0.
     * @return true if dims() is empty.
     * @return false otherwise.
     */
    bool empty() const { return mDims.empty(); }

    /**
     * @brief Read an element from its storage index.
     * @tparam expectedType Native type matching the Tensor's data type.
     * @param idx Index relative to this Tensor's storage offset (asserted to
     * be lower than size()).
     * @return Reference to the element in host memory.
     */
    template <typename expectedType>
    const expectedType& get(std::size_t idx) const {
        AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type");
        AIDGE_ASSERT(idx < mSize, "idx out of range");
        AIDGE_ASSERT(hasImpl(), "Tensor has no implementation");
        return *static_cast<const expectedType *>(mImpl->hostPtr(mImplOffset + idx));
    }

    /**
     * @brief Read an element from its coordinates (missing trailing
     * coordinates are taken as 0, see getStorageIdx()).
     */
    template <typename expectedType>
    const expectedType& get(std::vector<std::size_t> coordIdx) const {
        return get<expectedType>(getStorageIdx(coordIdx));
    }

    /**
     * @brief Write an element at its storage index.
     * @tparam expectedType Native type matching the Tensor's data type.
     * @param idx Index relative to this Tensor's storage offset (asserted to
     * be lower than size()).
     * @param value Value to store.
     */
    template <typename expectedType>
    void set(std::size_t idx, expectedType value){
        AIDGE_ASSERT(NativeType<expectedType>::type == mDataType, "wrong data type");
        AIDGE_ASSERT(idx < mSize, "idx out of range");
        AIDGE_ASSERT(hasImpl(), "Tensor has no implementation");
        expectedType* dataPtr = static_cast<expectedType*>(mImpl->hostPtr(mImplOffset + idx));
        *dataPtr = value;
    }

    /**
     * @brief Write an element at its coordinates (missing trailing
     * coordinates are taken as 0, see getStorageIdx()).
     */
    template <typename expectedType>
    void set(std::vector<std::size_t> coordIdx, expectedType value){
        set<expectedType>(getStorageIdx(coordIdx), value);
    }

    std::string toString() const;

    inline void print() const { printf("%s\n", toString().c_str()); }

    /**
     * @brief Return the gradient Tensor, creating it on first access.
     * @details A newly created gradient has the same data type and dimensions
     * as this Tensor and, if this Tensor has an implementation, the same
     * backend and device.
     * @return std::shared_ptr<Tensor>
     */
    std::shared_ptr<Tensor> grad() {
        if (!mGrad) {
            mGrad = std::make_shared<Tensor>(mDataType);
            mGrad->resize(mDims);

            if (mImpl) {
                // Forward the device index too, not only the backend name.
                const auto& device = mImpl->device();
                mGrad->setBackend(device.first, device.second);
            }
        }

        return mGrad;
    }

    /**
     * @brief From the 1D contiguous index, return the coordinate of an element in the tensor.
     * Beware: do not use this function with the storage index!
     * A Tensor without dimensions yields an empty coordinate vector.
     *
     * @param flatIdx 1D contiguous index of the value considering a flatten, contiguous, tensor.
     * @return std::vector<DimSize_t>
     */
    std::vector<std::size_t> getCoord(std::size_t flatIdx) const {
        std::vector<std::size_t> coordIdx(mDims.size());
        std::size_t idx = flatIdx;
        // Peel dimensions from the innermost (last) one outwards. Using
        // `i > 0` as the condition avoids the unsigned underflow of
        // `mDims.size() - 1` when the Tensor has no dimension.
        for (std::size_t i = mDims.size(); i > 0; --i) {
            coordIdx[i - 1] = idx % mDims[i - 1];
            idx /= mDims[i - 1];
        }
        return coordIdx;
    }

    /**
     * @brief From the coordinate returns the 1D contiguous index of an element in the tensor.
     * If the number of coordinates is inferior to the number of dimensions,
     * the remaining coordinates are assumed to be 0.
     * Beware: the contiguous index will only correspond to the storage index
     * if the tensor is contiguous!
     *
     * @param coordIdx Coordinate to an element in the tensor
     * @return DimSize_t Contiguous index
     */
    std::size_t getIdx(const std::vector<std::size_t>& coordIdx) const {
        AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
        std::size_t flatIdx = 0;
        for (std::size_t i = 0; i < mDims.size(); ++i) {
            std::size_t coord = 0;  // missing trailing coordinates are 0
            if (i < coordIdx.size()) {
                coord = coordIdx[i];
                AIDGE_ASSERT(coord < mDims[i], "Coordinates dimensions does not fit the dimensions of the tensor");
            }
            // Horner scheme for row-major ("nested") flattening.
            flatIdx = flatIdx * mDims[i] + coord;
        }
        return flatIdx;
    }

    /**
     * @brief From the coordinate returns the 1D storage index of an element in the tensor.
     * If the number of coordinates is inferior to the number of dimensions,
     * the remaining coordinates are assumed to be 0.
     *
     * @param coordIdx Coordinate to an element in the tensor
     * @return DimSize_t Storage index
     */
    std::size_t getStorageIdx(const std::vector<std::size_t>& coordIdx) const {
        AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Coordinates does not match number of dimensions");
        return std::inner_product(coordIdx.begin(), coordIdx.end(), mStrides.begin(), DimSize_t(0));
    }

    /**
     * @brief Returns a sub-tensor with one or more dimension less.
     * For instance, t.extract({1}) on a CHW tensor will return the HW tensor
     * of channel #1.
     * Likewise, t.extract({0, 1}) on a NCHW tensor will return the HW tensor
     * of batch #0 and channel #1.
     * No memory copy is performed, the returned tensor does not own the memory.
     * If the number of coordinates matches the number of dimensions, an empty
     * tensor is returned.
     * If the current tensor was contiguous, the returned tensor is guaranteed
     * to be contiguous as well.
     *
     * @param coordIdx Coordinates of the sub-tensor to extract
     * @return Tensor Sub-tensor.
    */
    Tensor extract(const std::vector<std::size_t>& coordIdx) const;

    /**
     * @brief Returns a sub-tensor at some coordinate and with some dimension.
     *
     * @param coordIdx First coordinates of the sub-tensor to extract
     * @param dims Dimensions of the sub-tensor to extract
     * @return Tensor Sub-tensor.
    */
    Tensor extract(const std::vector<std::size_t>& coordIdx, const std::vector<std::size_t>& dims) const;

    /**
     * @brief Make the tensor's storage contiguous, if it is not already the case.
     * If not contiguous, a new memory space is allocated.
    */
    void makeContiguous();

    /**
     * Copy-cast data from a Tensor on the same device.
     * If current tensor backend/device is set and is different from src, an
     * assertion is raised.
     * @param src Source tensor to copy-cast from.
    */
    void copyCast(const Tensor& src);

    /**
     * Copy data from a Tensor from another backend/device.
     * If current tensor data type is set and is different from src, an
     * assertion is raised.
     * @param src Source tensor to copy from.
    */
    void copyFrom(const Tensor& src);

    /**
     * Copy-cast data from a Tensor.
     * @param src Source tensor to copy-cast from.
     * @param movedSrc shared_ptr to an intermediate Tensor that will
     * contain the moved data if a device change should occur AND a type
     * conversion is necessary (otherwise it remains unused).
     * Any data already present will be overwritten. No new memory allocation
     * will occur if movedSrc has already been allocated with the right
     * type/size/device.
     * If required, memory is always allocated on current (destination)
     * Tensor's device.
    */
    void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc);

    /**
     * Copy-cast data from a Tensor.
     * In case of both a device change AND a data type conversion, an
     * intermediate buffer will be allocated and deallocated each time.
     * If required, buffer's memory is always allocated on current (destination)
     * Tensor's device.
     * @param src Source tensor to copy-cast from.
    */
    void copyCastFrom(const Tensor& src) {
        // Internal buffer will be allocated and deallocated at each call
        // (only if needed)
        std::shared_ptr<Tensor> movedSrc;
        copyCastFrom(src, movedSrc);
    }

    /**
     * Return a reference to a Tensor that is guaranteed to be contiguous:
     * - itself, if already contiguous;
     * - the provided Tensor, overwritten with the copied data.
     * The data type, backend and device stay the same.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
     * @return Reference to either itself or to fallback.
    */
    Tensor& refContiguous(std::shared_ptr<Tensor>& fallback);
    const Tensor& refContiguous(std::shared_ptr<Tensor>& fallback) const;

    /**
     * Return a reference to a Tensor casted to the desired data type:
     * - itself, if already at the right data type;
     * - the provided Tensor, overwritten with the copy-casted data.
     * The backend stays the same.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
     * @param dt The desired data type.
     * @return Reference to either itself or to fallback.
    */
    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt);
    const Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const;

    /**
     * Return a reference to a Tensor on the desired backend/device:
     * - itself, if already on the right device;
     * - the provided Tensor, overwritten with the copied data.
     * The data type stays the same.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
     * @param backend The desired backend.
     * @param device The desired device.
     * @return Reference to either itself or to fallback.
    */
    Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device = 0);
    const Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device = 0) const;

    /**
     * Return a reference to a Tensor on desired data type and backend/device:
     * - itself, if already with the right characteristics;
     * - the provided Tensor, overwritten with the copy-casted data.
     * If required, fallback is always allocated on desired (destination)
     * device.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
     * @param dt The desired data type.
     * @param backend The desired backend.
     * @param device The desired device.
     * @return Reference to either itself or to fallback.
    */
    Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) {
        // First refFrom, to ensure that fallback, if required, is also on desired device
        return refFrom(fallback, backend, device).refCast(fallback, dt);
    }

    /**
     * Return a reference to a Tensor with same characteristics
     * (data type, backend/device) as targetReqs Tensor:
     * - itself, if already with the right characteristics;
     * - the provided Tensor, overwritten with the copy-casted data.
     * If required, fallback is always allocated on current (destination)
     * Tensor's device.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
     * @param targetReqs Tensor with the desired target characteristics; it
     * must have an implementation (its backend/device is read from it).
     * @return Reference to either itself or to fallback.
    */
    Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) {
        AIDGE_ASSERT(targetReqs.hasImpl(), "targetReqs Tensor has no implementation");
        const auto& device = targetReqs.getImpl()->device();
        return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
    }

    /**
     * @brief Return a reference to a Tensor on desired data type and backend/device:
     * - itself, if already with the right characteristics;
     * - the provided Tensor, overwritten with the right characteristics.
     * @note no data is copy-casted. If it was so in a previous refCastFrom() on
     * the same fallback, it remains valid, otherwise, data is invalid.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
     * @param dt The desired data type.
     * @param backend The desired backend.
     * @param device The desired device.
     * @return Reference to either itself or to fallback.
    */
    Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0);
    const Tensor& ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device = 0) const;

    /**
     * @brief Return a reference to a Tensor with same characteristics
     * (data type, backend/device) as targetReqs Tensor:
     * - itself, if already with the right characteristics;
     * - the provided Tensor, overwritten with the right characteristics.
     * @note no data is copy-casted. If it was so in a previous refCastFrom() on
     * the same fallback, it remains valid, otherwise, data is invalid.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
     * @param targetReqs Tensor with the desired target characteristics; it
     * must have an implementation (its backend/device is read from it).
     * @return Reference to either itself or to fallback.
    */
    Tensor& ref(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) {
        AIDGE_ASSERT(targetReqs.hasImpl(), "targetReqs Tensor has no implementation");
        const auto& device = targetReqs.getImpl()->device();
        return ref(fallback, targetReqs.dataType(), device.first, device.second);
    }

   private:
    /**
     * @brief Compute the number of elements in the Tensor.
     * @note If dimensions are not empty, they are multiplied to get the total number
     * of elements. Else, the Tensor represents a scalar and contains a single element.
     */
    void computeSize() {
        mSize = std::accumulate(mDims.begin(), mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>());
    }
};
}  // namespace Aidge

#endif /* AIDGE_CORE_DATA_TENSOR_H_ */