// TensorImpl.hpp - CPU backend implementation of Aidge tensor storage.
#ifndef AIDGE_CPU_DATA_TENSORIMPL_H_
#define AIDGE_CPU_DATA_TENSORIMPL_H_

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

namespace Aidge {
template <class T>
class TensorImpl_cpu : public TensorImpl {
   private:
    const Tensor &mTensor;  // Impl needs to access Tensor information, but is not
                            // supposed to change it!
    std::vector<T> mData;

   public:
    static constexpr const char *Backend = "cpu";

    TensorImpl_cpu(const Tensor &tensor) : TensorImpl(Backend), mTensor(tensor) {}

    bool operator==(const TensorImpl &otherImpl) const override final {
        std::size_t i = 0;
        for (; i < mTensor.size() &&
               mData[i] == reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl).data()[i];
             ++i) {
        }
        return i == mTensor.size();
    }

    static std::unique_ptr<TensorImpl_cpu> create(const Tensor &tensor) {
        return std::make_unique<TensorImpl_cpu<T>>(tensor);
    }

    // native interface
    const std::vector<T> &data() const { return mData; }

    std::size_t scalarSize() const override { return sizeof(T); }

    void copy(const void *src, NbElts_t length) override {
        std::copy(static_cast<const T *>(src), static_cast<const T *>(src) + length,
                  static_cast<T *>(rawPtr()));
    }

    void *rawPtr() override {
        lazyInit(mData);
        return mData.data();
    };

   void* getRaw(std::size_t idx){
Maxence Naud's avatar
Maxence Naud committed
       return  static_cast<void*>(static_cast<T *>(rawPtr()) + idx);
    virtual ~TensorImpl_cpu() = default;

    void setRawPtr(void *ptr) override final {
        T *newPtr = static_cast<T *>(ptr);
        mData = std::vector<T>(newPtr, newPtr + mTensor.size());
    };

   private:
    void lazyInit(std::vector<T> &data) {
        assert(mTensor.dataType() == NativeType<T>::type);

        if (data.size() != mTensor.size()) data.resize(mTensor.size());
    }
};

// Register TensorImpl_cpu<T>::create as the factory for each ("cpu", DataType)
// key this backend supports. These objects live in an unnamed namespace inside
// a header, so every translation unit that includes this file gets its own
// copies and performs its own registration.
// NOTE(review): presumably Registrar tolerates re-registering the same key;
// verify against aidge/utils/Registrar.hpp.
// (Objects in an unnamed namespace already have internal linkage, so no
// `static` is needed.)
namespace {
Registrar<Tensor> registrarTensorImpl_cpu_Float64(
    {"cpu", DataType::Float64},
    TensorImpl_cpu<double>::create);

Registrar<Tensor> registrarTensorImpl_cpu_Float32(
    {"cpu", DataType::Float32},
    TensorImpl_cpu<float>::create);

Registrar<Tensor> registrarTensorImpl_cpu_Int32(
    {"cpu", DataType::Int32},
    TensorImpl_cpu<int>::create);
}  // namespace
}  // namespace Aidge

#endif /* AIDGE_CPU_DATA_TENSORIMPL_H_ */