diff --git a/include/aidge/backend/TensorImpl.hpp b/include/aidge/backend/TensorImpl.hpp index c56f66fc0b827ccccd9749b9880507dbf48c8179..dfe3d932ac68929acfd26ecf7126e07c4707bcfc 100644 --- a/include/aidge/backend/TensorImpl.hpp +++ b/include/aidge/backend/TensorImpl.hpp @@ -27,6 +27,9 @@ public: { printf("Cannot set raw pointer for backend %s\n", mBackend); }; + + virtual void* getRaw(std::size_t /*idx*/)=0; + virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes) constexpr const char *backend() const { return mBackend; } virtual ~TensorImpl() = default; diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 468f48feaecc62ee10cd980ce42f18c99d9bc549..0a67d73a9573c561f334e74fc64efa0e527d115b 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -446,18 +446,33 @@ class Tensor : public Data, */ bool empty() const { return mDims.empty(); } - template <typename expectedType, std::array<std::size_t, 1>::size_type DIM> - constexpr expectedType &get(std::array<std::size_t, DIM> idx) { - assert(DIM == mDims.size()); - assert(mImpl); - std::size_t unfoldedIdx = 0; - for (std::size_t i = 0; i < DIM - std::size_t(1); ++i) { - unfoldedIdx = (unfoldedIdx + idx[i]) * mDims[i + 1]; - } - unfoldedIdx += idx[DIM - 1]; - return static_cast<expectedType *>(mImpl->rawPtr())[unfoldedIdx]; + template <typename expectedType> + expectedType& get(std::size_t idx){ + // TODO : add assert expected Type compatible with datatype + // TODO : add assert idx < Size + return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + } + + template <typename expectedType> + expectedType& get(std::vector<std::size_t> coordIdx){ + return get<expectedType>(getIdx(coordIdx)); + } + + template <typename expectedType> + void set(std::size_t idx, expectedType value){ + // TODO : add assert expected Type compatible with datatype + // TODO : add assert idx < Size + void* dataPtr = mImpl->getRaw(idx); + std::memcpy(dataPtr, &value, sizeof(expectedType)); } + template <typename expectedType> + void set(std::vector<std::size_t> coordIdx, expectedType value){ + set<expectedType>(getIdx(coordIdx), value); + } + + + std::string toString() { if (dims().empty()) { return "{}"; } std::string res; @@ -565,10 +580,10 @@ class Tensor : public Data, * @param flatIdx 1D index of the value considering a flatten tensor. * @return std::vector<DimSize_t> */ - std::vector<DimSize_t> getCoord(DimSize_t flatIdx){ - std::vector<DimSize_t> coordIdx = {}; - DimSize_t idx = flatIdx; - for (DimSize_t d: mDims){ + std::vector<std::size_t> getCoord(std::size_t flatIdx){ + std::vector<std::size_t> coordIdx = {}; + std::size_t idx = flatIdx; + for (std::size_t d: mDims){ coordIdx.push_back(idx % d); idx/=d; } @@ -581,9 +596,9 @@ class Tensor : public Data, * @param coordIdx Coordinate to an element in the tensor * @return DimSize_t */ - DimSize_t getIdx(std::vector<DimSize_t> coordIdx){ - DimSize_t flatIdx = 0; - DimSize_t stride = 1; + std::size_t getIdx(std::vector<std::size_t> coordIdx){ + std::size_t flatIdx = 0; + std::size_t stride = 1; assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions"); for(std::size_t i=0; i< mDims.size(); ++i){ assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor"); diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 168c2c946efa297bbc876095fc4274a3df67b21c..31470e0eb2c50b5386b64498f89419801b133d3a 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -48,7 +48,10 @@ void addCtor(py::class_<Tensor, } return newTensor; - })); + })) + .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set) + .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set) + ; } @@ -84,15 +87,27 @@ void init_Tensor(py::module& m){ return b.size(); }) .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { - // TODO : Should return error if backend not compatible with get if (idx >= b.size()) throw py::index_error(); switch(b.dataType()){ case DataType::Float64: - return py::cast(static_cast<double*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<double>(idx)); case DataType::Float32: - return py::cast(static_cast<float*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<float>(idx)); case DataType::Int32: - return py::cast(static_cast<int*>(b.getImpl()->rawPtr())[idx]); + return py::cast(b.get<int>(idx)); + default: + return py::none(); + } + }) + .def("__getitem__", [](Tensor& b, std::vector<size_t> coordIdx)-> py::object { + if (b.getIdx(coordIdx) >= b.size()) throw py::index_error(); + switch(b.dataType()){ + case DataType::Float64: + return py::cast(b.get<double>(coordIdx)); + case DataType::Float32: + return py::cast(b.get<float>(coordIdx)); + case DataType::Int32: + return py::cast(b.get<int>(coordIdx)); default: return py::none(); } @@ -128,12 +143,12 @@ void init_Tensor(py::module& m){ } return py::buffer_info( - tensorImpl->rawPtr(), /* Pointer to buffer */ - tensorImpl->scalarSize(), /* Size of one scalar */ - dataFormatDescriptor, /* Python struct-style format descriptor */ - b.nbDims(), /* Number of dimensions */ - dims, /* Buffer dimensions */ - strides /* Strides (in bytes) for each index */ + tensorImpl->rawPtr(), /* Pointer to buffer */ + tensorImpl->scalarSize(), /* Size of one scalar */ + dataFormatDescriptor, /* Python struct-style format descriptor */ + b.nbDims(), /* Number of dimensions */ + dims, /* Buffer dimensions */ + strides /* Strides (in bytes) for each index */ ); });