From 53f2e8b82fe030c1c03a20f185c08ba336f040a4 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Wed, 20 Sep 2023 06:18:23 +0000 Subject: [PATCH] [Tensor] Add get & set method at Tensor level. --- include/aidge/backend/TensorImpl.hpp | 3 ++ include/aidge/data/Tensor.hpp | 49 +++++++++++++++++---------- python_binding/data/pybind_Tensor.cpp | 37 ++++++++++++++------ 3 files changed, 61 insertions(+), 28 deletions(-) diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index c56f66fc0..dfe3d932a 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -27,6 +27,9 @@ public: { printf("Cannot set raw pointer for backend %s\n", mBackend); }; + + virtual void* getRaw(std::size_t /*idx*/)=0; + virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) constexpr const char *backend() const { return mBackend; } virtual ~TensorImpl() = default; diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 468f48fea..0a67d73a9 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -446,18 +446,33 @@ class Tensor : public Data, */ bool empty() const { return mDims.empty(); } - template <typename expectedType, std::array<std::size_t, 1>::size_type DIM> - constexpr expectedType &get(std::array<std::size_t, DIM> idx) { - assert(DIM == mDims.size()); - assert(mImpl); - std::size_t unfoldedIdx = 0; - for (std::size_t i = 0; i < DIM - std::size_t(1); ++i) { - unfoldedIdx = (unfoldedIdx + idx[i]) * mDims[i + 1]; - } - unfoldedIdx += idx[DIM - 1]; - return static_cast<expectedType *>(mImpl->rawPtr())[unfoldedIdx]; + template <typename expectedType> + expectedType& get(std::size_t idx){ + // TODO : add assert expected Type compatible with datatype + // TODO : add assert idx < Size + return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + } + + template <typename expectedType> + expectedType& get(std::vector<std::size_t> coordIdx){ + return get<expectedType>(getIdx(coordIdx)); + } + + template <typename expectedType> + void set(std::size_t idx, expectedType value){ + // TODO : add assert expected Type compatible with datatype + // TODO : add assert idx < Size + void* dataPtr = mImpl->getRaw(idx); + std::memcpy(dataPtr, &value, sizeof(expectedType)); } + template <typename expectedType> + void set(std::vector<std::size_t> coordIdx, expectedType value){ + set<expectedType>(getIdx(coordIdx), value); + } + + + std::string toString() { if (dims().empty()) { return "{}"; } std::string res; @@ -565,10 +580,10 @@ class Tensor : public Data, * @param flatIdx 1D index of the value considering a flatten tensor. * @return std::vector<DimSize_t> */ - std::vector<DimSize_t> getCoord(DimSize_t flatIdx){ - std::vector<DimSize_t> coordIdx = {}; - DimSize_t idx = flatIdx; - for (DimSize_t d: mDims){ + std::vector<std::size_t> getCoord(std::size_t flatIdx){ + std::vector<std::size_t> coordIdx = {}; + std::size_t idx = flatIdx; + for (std::size_t d: mDims){ coordIdx.push_back(idx % d); idx/=d; } @@ -581,9 +596,9 @@ class Tensor : public Data, * @param coordIdx Coordinate to an element in the tensor * @return DimSize_t */ - DimSize_t getIdx(std::vector<DimSize_t> coordIdx){ - DimSize_t flatIdx = 0; - DimSize_t stride = 1; + std::size_t getIdx(std::vector<std::size_t> coordIdx){ + std::size_t flatIdx = 0; + std::size_t stride = 1; assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions"); for(std::size_t i=0; i< mDims.size(); ++i){ assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor"); diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 168c2c946..31470e0eb 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -48,7 +48,10 @@ void addCtor(py::class_<Tensor, } return newTensor; - })); + })) + .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set) + .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set) + ; } @@ -84,15 +87,27 @@ void init_Tensor(py::module& m){ return b.size(); }) .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { - // TODO : Should return error if backend not compatible with get if (idx >= b.size()) throw py::index_error(); switch(b.dataType()){ case DataType::Float64: - return py::cast(static_cast<double*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<double>(idx)); case DataType::Float32: - return py::cast(static_cast<float*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<float>(idx)); case DataType::Int32: - return py::cast(static_cast<int*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<int>(idx)); + default: + return py::none(); + } + }) + .def("__getitem__", [](Tensor& b, std::vector<size_t> coordIdx)-> py::object { + if (b.getIdx(coordIdx) >= b.size()) throw py::index_error(); + switch(b.dataType()){ + case DataType::Float64: + return py::cast(b.get<double>(coordIdx)); + case DataType::Float32: + return py::cast(b.get<float>(coordIdx)); + case DataType::Int32: + return py::cast(b.get<int>(coordIdx)); default: return py::none(); } @@ -128,12 +143,12 @@ void init_Tensor(py::module& m){ } return py::buffer_info( - tensorImpl->rawPtr(), /* Pointer to buffer */ - tensorImpl->scalarSize(), /* Size of one scalar */ - dataFormatDescriptor, /* Python struct-style format descriptor */ - b.nbDims(), /* Number of dimensions */ - dims, /* Buffer dimensions */ - strides /* Strides (in bytes) for each index */ + tensorImpl->rawPtr(), /* Pointer to buffer */ + tensorImpl->scalarSize(), /* Size of one scalar */ + dataFormatDescriptor, /* Python struct-style format descriptor */ + b.nbDims(), /* Number of dimensions */ + dims, /* Buffer dimensions */ + strides /* Strides (in bytes) for each index */ ); }); -- GitLab