Skip to content
Snippets Groups Projects
TensorImpl.hpp 4.49 KiB
Newer Older
Cyril Moineau's avatar
Cyril Moineau committed
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_TENSORIMPL_H_
#define AIDGE_TENSORIMPL_H_
Cyril Moineau's avatar
Cyril Moineau committed

#include <cstddef>
#include <cstdio>
#include <string>
#include <utility>

#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
Olivier BICHLER's avatar
Olivier BICHLER committed
#include "aidge/utils/ErrorHandling.hpp"
Cyril Moineau's avatar
Cyril Moineau committed

namespace Aidge {
Olivier BICHLER's avatar
Olivier BICHLER committed
/**
 * This class manages the raw data storage of a Tensor and provide generic copy
 * primitives from other devices and from/to host.
 * It can own the data or not (use setRawPtr() to set an external data owner).
 * It only knows the data type and data capacity, but does not handle anything else.
*/
Cyril Moineau's avatar
Cyril Moineau committed
class TensorImpl {
public:
    TensorImpl() = delete;
    TensorImpl(const char *backend, int device = 0) : mBackend(backend), mDevice(device){};

    /**
     * Return the (backend, device) pair for this implementation.
    */
    std::pair<std::string, int> device() const { return std::make_pair(mBackend, mDevice); }

    /**
     * Set the device ID for current backend.
     * @param device New device ID on current backend.
    */
    virtual void setDevice(int device) = 0;

    /**
     * Copy data from the same device.
     * @param src Pointer on current implementation device.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @param length Number of elements to copy.
     * @param offset Destination offset (in number of elements).
    virtual void copy(const void *src, NbElts_t length, NbElts_t offset = 0) = 0;

    /**
     * Copy-convert data from the same device.
     * @param srcDt Source data type.
     * @param src Pointer on current implementation device.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @param length Number of elements to copy.
    */
    virtual void copyCast(const void *src, NbElts_t length, const DataType srcDt) = 0;

    /**
     * Copy data from an other device on the same backend.
     * @param device (backend, device) pair to copy from. The backend must match current implementation backend.
     * @param src Pointer on current implementation backend.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @param length Number of elements to copy.
    */
    virtual void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, int>& device) = 0;

    /**
     * Copy data from host.
     * @param src Host pointer to copy from.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @param length Number of elements to copy.
    */
    virtual void copyFromHost(const void *src, NbElts_t length) = 0;

    /**
     * Copy data to host.
     * @param src Host pointer to copy to.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @param length Number of elements to copy.
    virtual void copyToHost(void *dst, NbElts_t length) const = 0;

    /**
     * Return the raw device pointer.
     * The raw pointer is garanteed to be valid only on the *same* device.
    */
    virtual void* rawPtr() = 0;
    virtual const void* rawPtr() const = 0;

    /**
     * Return the host pointer.
     * If the implementation does not have a valid host pointer, nullptr is returned.
    */
    virtual void* hostPtr() { return nullptr; };
    virtual const void* hostPtr() const { return nullptr; };
    /**
     * Get the device pointer with an offset (in number of elements).
    */
    virtual void* getRawPtr(NbElts_t idx) = 0;

Olivier BICHLER's avatar
Olivier BICHLER committed
     * Sets the device pointer. The previously owned data is deleted.
     * UNSAFE: directly setting the device pointer may lead to undefined behavior
     * if it does not match the required storage.
     * @param ptr A valid device pointer.
Olivier BICHLER's avatar
Olivier BICHLER committed
     * @param length Storage capacity at the provided pointer
Olivier BICHLER's avatar
Olivier BICHLER committed
    virtual void setRawPtr(void* /*ptr*/, NbElts_t /*length*/)
Cyril Moineau's avatar
Cyril Moineau committed
    {
Olivier BICHLER's avatar
Olivier BICHLER committed
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
    virtual std::size_t size() const = 0; // Storage size
Cyril Moineau's avatar
Cyril Moineau committed
    virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
    constexpr const char *backend() const { return mBackend; }
    virtual ~TensorImpl() = default;
    virtual bool operator==(const TensorImpl &othImpl) const = 0;

    /**
     * Copy from another backend.
     * @param srcImpl Source TensorImpl to copy from.
     * @param length Number of elements of size scalarSize() to copy
    */
    void copyFrom(const TensorImpl& srcImpl, NbElts_t length);

Cyril Moineau's avatar
Cyril Moineau committed
private:
    const char *mBackend;
Cyril Moineau's avatar
Cyril Moineau committed
};

} // namespace Aidge