/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_DATA_TENSOR_H_
#define AIDGE_CORE_DATA_TENSOR_H_

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {

// Helper to create default arrays
template <typename T, std::size_t ... Is>
constexpr std::array<T, sizeof...(Is)>
create_array_impl(T value, std::index_sequence<Is...>)
{
    // cast Is to void to remove the warning: unused value
    return {{(static_cast<void>(Is), value)...}};
}

template <typename T, std::size_t N>
constexpr std::array<T, N> create_array(const T& value)
{
    return create_array_impl(value, std::make_index_sequence<N>());
}
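
// Example (illustrative): create_array<int, 4>(7) yields std::array<int, 4>{{7, 7, 7, 7}}.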


// Helper to convert an iterator range to an std::array
template <typename T, typename Iter, std::size_t... Is>
constexpr auto to_array(Iter &iter, std::index_sequence<Is...>) -> std::array<T, sizeof...(Is)> {
    return {{((void)Is, T(*iter++))...}};
}

/**
 * @brief Convert an object with an iterator to an std::array.
 */
template <std::size_t N, typename U = void, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type,
          typename T = std::conditional_t<std::is_same<U, void>{}, V, U>>
constexpr auto to_array(Iter iter) -> std::array<T, N> {
    return to_array<T>(iter, std::make_index_sequence<N>{});
}

namespace detail {

template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], std::index_sequence<I...>) {
    return {{a[I]...}};
}

}  // namespace detail

/**
 * @brief Convert a C-style array into a C++ std::array.
 *
 * @tparam T Data type.
 * @tparam N Number of elements.
 * @param a C-style array to convert.
 * @return constexpr std::array<std::remove_cv_t<T>, N>
 */
template <class T, std::size_t N>
constexpr std::array<std::remove_cv_t<T>, N> to_array(T (&a)[N]) {
    return detail::to_array_impl(a, std::make_index_sequence<N>{});
}

template <typename T, std::size_t N, std::size_t... I>
constexpr std::array<T, N + 1> append(std::array<T, N> a, T t, std::index_sequence<I...>) {
    return std::array<T, N + 1>{a[I]..., t};
}

template <typename T, std::size_t N, std::size_t... I>
constexpr std::array<T, N + 1> append(T t, std::array<T, N> a, std::index_sequence<I...>) {
    return std::array<T, N + 1>{t, a[I]...};
}

/**
 * @brief Create a new array concatenating the initial one with the value to
 * add.
 * @details append({1,2,7}, 3) -> {1,2,7,3}
 *
 * @tparam T Data type.
 * @tparam N Number of elements in the initial array.
 * @param a Initial array.
 * @param t Element to add.
 * @return constexpr std::array<T, N + 1>
 */
template <typename T, std::size_t N>
constexpr std::array<T, N + 1> append(std::array<T, N> a, T t) {
    return append(a, t, std::make_index_sequence<N>());
}

template <typename T, std::size_t N>
constexpr std::array<T, N + 1> append(T t, std::array<T, N> a) {
    return append(t, a, std::make_index_sequence<N>());
}

// Generic helper for initializing a Tensor
template <typename T, std::size_t SIZE_0>
struct Array1D {
    T data[SIZE_0];
};

template <typename T, std::size_t SIZE_0, std::size_t SIZE_1>
struct Array2D {
    T data[SIZE_0][SIZE_1];
};

template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2>
struct Array3D {
    T data[SIZE_0][SIZE_1][SIZE_2];
};

template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3>
struct Array4D {
    T data[SIZE_0][SIZE_1][SIZE_2][SIZE_3];
};

/**
 * @brief Description of the tensor data structure.
 * @details Sets the properties of the tensor (data type, dimensions) without directly holding any data.
 * Contains a pointer to the actual contiguous data implementation.
 */
class Tensor : public Data,
               public Registrable<Tensor, std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor &)> {
   private:
    DataType mDataType; /** enum to specify data type. */
    std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */
    std::unique_ptr<TensorImpl> mImpl; /** Pointer to the actual data implementation. */
    std::shared_ptr<Tensor> mGrad; /** Pointer to the associated gradient Tensor instance. */

    // Cached data
    std::size_t mSize;    /** Number of elements in the Tensor. */
    std::size_t mSizeM1;  /** Number of elements in the last N-1 dimensions (all dimensions except the first). */

   public:
    static constexpr const char *Type = "Tensor";

    /**
     * @brief Construct a new empty Tensor object.
     * @param dataType Sets the type of inserted data.
     */
    Tensor(DataType dataType = DataType::Float32)
        : Data(Type),
          mDataType(dataType),
          mDims({}),
          mSize(0),
          mSizeM1(0)
    {
        // ctor
    }

    /**
     * @brief Construct a new Tensor object copied from another one.
     * @param otherTensor
     */
    Tensor(const Tensor& otherTensor)
        : Data(Type),
          mDataType(otherTensor.mDataType),
          mDims(otherTensor.mDims),
          mSize(otherTensor.mSize),
          mSizeM1(otherTensor.mSizeM1)
    {
        if (otherTensor.hasImpl()) {
            mImpl = Registrar<Tensor>::create({otherTensor.mImpl->backend(), dataType()})(*this);
            mImpl->setDevice(otherTensor.mImpl->device().second);
            // Same backend, same device => directly use copy()
            mImpl->copy(otherTensor.mImpl->rawPtr(), mSize);
        }
    }

    /**
     * @brief Construct a new Tensor object from the 1-dimension Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     */
    template <typename T, std::size_t SIZE_0>
    constexpr Tensor(Array1D<T, SIZE_0> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
          mSize(SIZE_0),
          mSizeM1(SIZE_0) {
        mImpl->copyFromHost(&arr.data[0], SIZE_0);
    }

    template <typename T, std::size_t SIZE_0>
    constexpr Tensor &operator=(Array1D<T, SIZE_0> &&arr) {
        resize({SIZE_0});
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
        }
        mImpl->copyFromHost(&arr.data[0], SIZE_0);
        return *this;
    }

    /**
     * @brief Construct a new Tensor object from the 2-dimensions Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     * @tparam SIZE_1 second array dimension.
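     *
     * Usage sketch (illustrative, assuming a "cpu" backend implementation is registered):
     * @code
     * Tensor t = Array2D<float, 2, 3>{{{1.0f, 2.0f, 3.0f},
     *                                  {4.0f, 5.0f, 6.0f}}};
     * @endcode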
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1>
    constexpr Tensor(Array2D<T, SIZE_0, SIZE_1> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
          mSize(SIZE_0 * SIZE_1),
          mSizeM1(SIZE_1) {
        mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
    }

    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1>
    constexpr Tensor &operator=(Array2D<T, SIZE_0, SIZE_1> &&arr) {
        resize({SIZE_0, SIZE_1});
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
        }
        mImpl->copyFromHost(&arr.data[0][0], SIZE_0 * SIZE_1);
        return *this;
    }

    /**
     * @brief Construct a new Tensor object from the 3-dimensions Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     * @tparam SIZE_1 second array dimension.
     * @tparam SIZE_2 third array dimension.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2>
    constexpr Tensor(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1, SIZE_2}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
          mSize(SIZE_0 * SIZE_1 * SIZE_2),
          mSizeM1(SIZE_1 * SIZE_2) {
        mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
    }

    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2>
    constexpr Tensor &operator=(Array3D<T, SIZE_0, SIZE_1, SIZE_2> &&arr) {
        resize({SIZE_0, SIZE_1, SIZE_2});
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
        }
        mImpl->copyFromHost(&arr.data[0][0][0], SIZE_0 * SIZE_1 * SIZE_2);
        return *this;
    }

    /**
     * @brief Construct a new Tensor object from the 4-dimensions Array helper.
     * @tparam T datatype
     * @tparam SIZE_0 first array dimension.
     * @tparam SIZE_1 second array dimension.
     * @tparam SIZE_2 third array dimension.
     * @tparam SIZE_3 fourth array dimension.
     */
    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3>
    constexpr Tensor(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr)
        : Data(Type),
          mDataType(NativeType<T>::type),
          mDims({SIZE_0, SIZE_1, SIZE_2, SIZE_3}),
          mImpl(Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this)),
          mSize(SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3),
          mSizeM1(SIZE_1 * SIZE_2 * SIZE_3) {
        mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
    }

    template <typename T, std::size_t SIZE_0, std::size_t SIZE_1, std::size_t SIZE_2, std::size_t SIZE_3>
    constexpr Tensor &operator=(Array4D<T, SIZE_0, SIZE_1, SIZE_2, SIZE_3> &&arr) {
        resize({SIZE_0, SIZE_1, SIZE_2, SIZE_3});
        if (!mImpl) {
            mImpl = Registrar<Tensor>::create({"cpu", NativeType<T>::type})(*this);
        }
        mImpl->copyFromHost(&arr.data[0][0][0][0], SIZE_0 * SIZE_1 * SIZE_2 * SIZE_3);
        return *this;
    }

    /**
     * @brief Copy dimensions, datatype and data of another Tensor.
     * @param t other Tensor object.
     * @return Tensor&
     */
    Tensor &operator=(const Tensor &t) {
        resize(t.dims());
        setDataType(t.dataType());
        if (t.hasImpl()) {
            if (hasImpl()) {
                copyCastFrom(t);
            }
            else {
                mImpl = Registrar<Tensor>::create({t.mImpl->backend(), dataType()})(*this);
                mImpl->setDevice(t.mImpl->device().second);
                // Same backend, same device => directly use copy()
                mImpl->copy(t.mImpl->rawPtr(), mSize);
            }
        }
        else {
            mImpl = nullptr;
        }
        return *this;
    }

    /**
     * @brief Assess whether the data type, dimensions, backend and data are the same.
     * @param otherTensor
     */
    bool operator==(const Tensor &otherTensor) const {
        // A Tensor with no implementation never compares equal; this also guards
        // against dereferencing a null implementation below.
        if ((!mImpl || !otherTensor.mImpl) || (dataType() != otherTensor.dataType()) ||
            (dims() != otherTensor.dims()) || (mImpl->backend() != otherTensor.mImpl->backend())) {
            return false;
        }
        return *mImpl == *(otherTensor.mImpl);
    }

    /**
     * @brief Set the backend of the Tensor's associated implementation.
     * @details Creates and initializes an implementation if none was associated yet.
     * @param name
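     * @param device Index of the device to use for the given backend.
     *
     * Usage sketch (illustrative, assuming a "cpu" backend implementation is registered):
     * @code
     * Tensor t(DataType::Float32);
     * t.resize({2, 3});
     * t.setBackend("cpu"); // attaches a "cpu" implementation on device 0
     * @endcode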
     */
    inline void setBackend(const std::string &name, int device = 0) {
        if (mImpl) {
            if (mImpl->device() != std::make_pair(name, device)) {
                // Backend change: create new impl, copy from old to new and replace
                // impl
                std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this);
                newImpl->setDevice(device);
                newImpl->copyFrom(*mImpl, size());
                mImpl = std::move(newImpl);
            }
        }
        else {
            mImpl = Registrar<Tensor>::create({name, mDataType})(*this);
            mImpl->setDevice(device);
        }
    }

    /**
     * @brief Get a list of available backends.
     * @return std::set<std::string>
     */
    static std::set<std::string> getAvailableBackends(){
        std::set<std::string> backendsList;
        for(std::tuple<std::string, DataType> tupleKey : Registrar<Tensor>::getKeys())
            backendsList.insert(std::get<0>(tupleKey));
        return backendsList;
    }

    /**
     * @brief Get the data type enum.
     * @return constexpr DataType
     */
    constexpr DataType dataType() const { return mDataType; }

    /**
     * @brief Set the DataType of the Tensor and convert its data
     * if the Tensor has already been initialized.
     * @param dt DataType.
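     *
     * Usage sketch (illustrative):
     * @code
     * t.setDataType(DataType::Int32); // existing data, if any, is copy-cast to Int32
     * @endcode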
     */
    void setDataType(const DataType dt) {
        if (mImpl && (dataType() != dt)) {
            std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
            newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
            mImpl = std::move(newImpl);
        }
        mDataType = dt;
    }

    /**
     * @brief Get the Impl object
     * @return constexpr const std::unique_ptr<TensorImpl>&
     */
    constexpr const std::unique_ptr<TensorImpl> &getImpl() { return mImpl; }
    constexpr const std::unique_ptr<TensorImpl> &getImpl() const { return mImpl; }

    /**
     * @brief Return whether an implementation has been associated.
     * @return true
     * @return false
     */
    bool hasImpl() const { return static_cast<bool>(mImpl); }

    /**
     * @brief Get number of dimensions of the Tensor.
     * @return std::size_t
     */
    inline std::size_t nbDims() const { return mDims.size(); }

    /**
     * @brief Get dimensions of the Tensor object.
     * @tparam DIM number of dimensions.
     * @return constexpr std::array<DimSize_t, DIM>
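     *
     * Usage sketch (illustrative, for a 4-D tensor):
     * @code
     * const std::array<DimSize_t, 4> d = t.dims<4>();
     * @endcode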
     */
    template <DimIdx_t DIM>
    constexpr std::array<DimSize_t, DIM> dims() const {
        assert(DIM == mDims.size() && "wrong number of dimensions");
        return to_array<DIM>(mDims.cbegin());
    }

    /**
     * @brief Get dimensions of the Tensor object.
     * @return constexpr const std::vector<DimSize_t>&
     */
    constexpr const std::vector<DimSize_t> &dims() const { return mDims; }

    /**
     * @brief Get the number of elements in the Tensor object.
     * @return constexpr std::size_t
     */
    constexpr std::size_t size() const { return mSize; }

    /**
     * @brief Get the number of elements in the last N-1 dimensions of the Tensor object (all dimensions except the first).
     * @return constexpr std::size_t
     */
    constexpr std::size_t sizeM1() const { return mSizeM1; }

    /**
     * @brief Change the shape of the Tensor object according to the given argument.
     * @tparam DIM new dimensions.
     * @param dims
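     *
     * Usage sketch (illustrative):
     * @code
     * t.resize({2, 3, 4}); // updates dims and cached sizes; the implementation buffer is untouched here
     * @endcode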
     */
    template <std::array<DimSize_t, 1>::size_type DIM> // deducing std::array size_type and declaring DIM accordingly
    void resize(const std::array<DimSize_t, DIM> &dims) {
        static_assert(DIM<=MaxDim,"Too many tensor dimensions required by resize, not supported");
        mDims.assign(dims.begin(), dims.end());
        computeSize();
    }

    void resize(const std::vector<DimSize_t> &dims) {
        mDims = dims;
        computeSize();
    }

    /**
     * @brief Return whether the Tensor object holds no element (i.e. has no dimensions).
     * @return true
     * @return false
     */
    bool empty() const { return mDims.empty(); }

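    /**
     * @brief Get a (mutable) reference to the element at the given flat index.
     * @warning expectedType must match the Tensor's actual data type and idx must be
     * in range; neither is currently checked (see TODOs below).
     * @tparam expectedType Native type corresponding to the Tensor's DataType.
     * @param idx Flat index of the element.
     */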
    template <typename expectedType>
    expectedType& get(std::size_t idx){
        // TODO : add assert expected Type compatible with datatype
        // TODO : add assert idx < Size
        return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
    }

    template <typename expectedType>
    expectedType& get(std::vector<std::size_t> coordIdx){
        return get<expectedType>(getIdx(coordIdx));
    }

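    /**
     * @brief Set the element at the given flat index to the given value.
     * @warning expectedType must match the Tensor's actual data type and idx must be
     * in range; neither is currently checked (see TODOs below).
     * @tparam expectedType Native type corresponding to the Tensor's DataType.
     * @param idx Flat index of the element.
     * @param value Value to write at that index.
     */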
    template <typename expectedType>
    void set(std::size_t idx, expectedType value){
        // TODO : add assert expected Type compatible with datatype
        // TODO : add assert idx < Size
        void* dataPtr = mImpl->getRaw(idx);
        std::memcpy(dataPtr, &value, sizeof(expectedType));
    }

    template <typename expectedType>
    void set(std::vector<std::size_t> coordIdx, expectedType value){
        set<expectedType>(getIdx(coordIdx), value);
    }



    std::string toString() const {
        // TODO: move lambda elsewhere?
        auto ptrToString = [](DataType dt, void* ptr, size_t idx) {
            switch (dt) {
            case DataType::Float64:
                return std::to_string(static_cast<double*>(ptr)[idx]);
            case DataType::Float32:
                return std::to_string(static_cast<float*>(ptr)[idx]);
            case DataType::Float16:
                return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
            case DataType::Int8:
                return std::to_string(static_cast<int8_t*>(ptr)[idx]);
            case DataType::Int16:
                return std::to_string(static_cast<int16_t*>(ptr)[idx]);
            case DataType::Int32:
                return std::to_string(static_cast<int32_t*>(ptr)[idx]);
            case DataType::Int64:
                return std::to_string(static_cast<int64_t*>(ptr)[idx]);
            case DataType::UInt8:
                return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
            case DataType::UInt16:
                return std::to_string(static_cast<uint16_t*>(ptr)[idx]);
            case DataType::UInt32:
                return std::to_string(static_cast<uint32_t*>(ptr)[idx]);
            case DataType::UInt64:
                return std::to_string(static_cast<uint64_t*>(ptr)[idx]);
            default:
                AIDGE_ASSERT(false, "unsupported type to convert to string");
            }
            return std::string(); // unreachable if the assertion throws; avoids a missing-return warning
        };

        if (dims().empty()) { return "{}"; }
        std::string res;
        std::size_t dim = 0;
        std::size_t counter = 0;
        if (nbDims()>=2) {
            std::vector<std::size_t> dimVals(nbDims(), 0);
            res += "{\n";
            while (counter < mSize) {
                std::string spaceString = std::string((dim+1)<<1,' ');
                if (dim < nbDims()-2) {
                    if (dimVals[dim] == 0) {
                        res += spaceString + "{\n";
                        ++dim;
                    } else if (dimVals[dim] < static_cast<std::size_t>(dims()[dim])) {
                        res += spaceString + "},\n" + spaceString + "{\n";
                        ++dim;
                    } else {
                        res += spaceString + "}\n";
                        dimVals[dim--] = 0;
                        dimVals[dim]++;
                    }
                } else {
                    for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) {
                        res += spaceString + "{";
                        for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) {
                            res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + ",";
                        }
                        res += " " + ptrToString(mDataType, mImpl->rawPtr(), counter++) + "}";
                        if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) {
                            res += ",";
                        }
                        res += "\n";
                    }
                    if (dim == 0) {
                        break;
                    }
                    dimVals[dim--] = 0;
                    dimVals[dim]++;
                }
            }

            for(int i = static_cast<int>(dim); i > 0; --i) {
                res += std::string((dim+1)<<1,' ') + "}\n";
            }
        } else {
            res += "{";
            for (DimSize_t j = 0; j < dims()[0]; ++j) {
                res += " " + ptrToString(mDataType, mImpl->rawPtr(), j) + ((j < dims()[0]-1) ? "," : "");
            }
        }
        res += "}";
        return res;
    }

    inline void print() const { printf("%s\n", toString().c_str()); }

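    /**
     * @brief Get the gradient Tensor associated with this Tensor, lazily creating it
     * on first access with the same dimensions, data type and backend as this Tensor.
     * @return std::shared_ptr<Tensor>
     */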
    std::shared_ptr<Tensor> grad() {
        if (!mGrad) {
            mGrad = std::make_shared<Tensor>(mDataType);
            mGrad->resize(mDims);

            if (mImpl) mGrad->setBackend(mImpl->backend());
        }

        return mGrad;
    }

    /**
     * @brief From the 1D index, return the coordinates of an element in the tensor.
     *
     * @param flatIdx 1D index of the value, considering a flattened tensor.
     * @return std::vector<DimSize_t>
     */
    std::vector<std::size_t> getCoord(std::size_t flatIdx) const {
        std::vector<std::size_t> coordIdx = std::vector<std::size_t>(mDims.size());
        std::size_t idx = flatIdx;
        for (std::size_t i = mDims.size() - 1; i > 0; --i){
            coordIdx[i] = (idx % mDims[i]);
            idx/=mDims[i];
        }
        coordIdx[0] = idx % mDims[0];
        return coordIdx;
    }

    /**
     * @brief From the coordinate returns the 1D index of an element in the tensor.
     *
     * @param coordIdx Coordinate to an element in the tensor
     * @return DimSize_t
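     *
     * Usage sketch (illustrative, for a tensor of dims {2, 3, 4}):
     * @code
     * const std::size_t flat = t.getIdx({1, 2, 3});        // 1*(3*4) + 2*4 + 3 = 23
     * const std::vector<std::size_t> c = t.getCoord(flat); // {1, 2, 3}
     * @endcode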
     */
    std::size_t getIdx(std::vector<std::size_t> coordIdx) const {
        std::size_t flatIdx = 0;
        assert(coordIdx.size() == mDims.size() && "Coordinates do not match the number of dimensions");
        std::size_t i = 0;
        for(; i < mDims.size() - 1; ++i){
            assert(coordIdx[i] < mDims[i] && "Coordinate exceeds the corresponding tensor dimension");
            flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1];
        }
        return flatIdx + coordIdx[i];
    }

    /**
     * Copy-cast data from a Tensor.
     * @param src Source tensor to copy-cast from.
     * @param movedSrc shared_ptr to an intermediate Tensor that will
     * contain the moved data if a device change should occur AND a type 
     * conversion is necessary (otherwise it remains unused).
     * Any data already present will be overwritten. No new memory allocation 
     * will occur if movedSrc has already been allocated with the right 
     * type/size/device.
     * If required, memory is always allocated on current (destination) 
     * Tensor's device.
    */
    void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrc);

    /**
     * Copy-cast data from a Tensor.
     * In case of both a device change AND a data type conversion, an
     * intermediate buffer will be allocated and deallocated each time.
     * If required, buffer's memory is always allocated on current (destination) 
     * Tensor's device.
     * @param src Source tensor to copy-cast from.
    */
    void copyCastFrom(const Tensor& src) {
        // Internal buffer will be allocated and deallocated at each call
        // (only if needed)
        std::shared_ptr<Tensor> movedSrc;
        copyCastFrom(src, movedSrc);
    }

    /**
     * Return a reference to a Tensor cast to the desired data type:
     * - itself, if already of the right data type;
     * - the provided Tensor, overwritten with the copy-cast data.
     * The backend stays the same.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right 
     * type/size/device.
     * @param dt The desired data type.
     * @return Reference to either itself or to fallback.
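     *
     * Usage sketch (illustrative; fallback is only written to when a cast is actually needed):
     * @code
     * std::shared_ptr<Tensor> fallback;
     * const Tensor& asDouble = t.refCast(fallback, DataType::Float64);
     * @endcode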
    */
    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt);
    const Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const;

    /**
     * Return a reference to a Tensor on the desired backend/device:
     * - itself, if already on the right device;
     * - the provided Tensor, overwritten with the copied data.
     * The data type stays the same.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right 
     * type/size/device.
     * @param backend The desired backend.
     * @param device The desired device.
     * @return Reference to either itself or to fallback.
    */
    Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0);
    const Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0) const;

    /**
     * Return a reference to a Tensor with same characteristics
     * (data type, backend/device) as target Tensor:
     * - itself, if already with the right characteristics;
     * - the provided Tensor, overwritten with the copy-cast data.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right 
     * type/size/device.
     * @param target Tensor with the desired target characteristics.
     * @return Reference to either itself or to fallback.
    */
    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Tensor& target) {
        const auto& device = target.getImpl()->device();
        return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
    }

private:
    ///\bug not protected against overflow
    std::size_t computeSize() {
        if (mDims.empty()) {
            mSizeM1 = DimSize_t(0);
            mSize = DimSize_t(0);
        }
        else if (mDims.size() == 1)
        {
            mSizeM1 = mDims[0];
            mSize = mDims[0];
        }
        else {
            mSizeM1 = std::accumulate(++mDims.begin(),mDims.end(), DimSize_t(1), std::multiplies<DimSize_t>());
            mSize = static_cast<std::size_t>(mSizeM1 * mDims[0]);
        }

        return mSize;
    }
};
}  // namespace Aidge

#endif /* AIDGE_CORE_DATA_TENSOR_H_ */