Skip to content
Snippets Groups Projects
Commit 53f2e8b8 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Tensor] Add get & set method at Tensor level.

parent 4f3d4fbc
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,9 @@ public:
{
printf("Cannot set raw pointer for backend %s\n", mBackend);
};
virtual void* getRaw(std::size_t /*idx*/)=0;
virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
constexpr const char *backend() const { return mBackend; }
virtual ~TensorImpl() = default;
......
......@@ -446,18 +446,33 @@ class Tensor : public Data,
*/
bool empty() const { return mDims.empty(); }
template <typename expectedType, std::array<std::size_t, 1>::size_type DIM>
constexpr expectedType &get(std::array<std::size_t, DIM> idx) {
assert(DIM == mDims.size());
assert(mImpl);
std::size_t unfoldedIdx = 0;
for (std::size_t i = 0; i < DIM - std::size_t(1); ++i) {
unfoldedIdx = (unfoldedIdx + idx[i]) * mDims[i + 1];
}
unfoldedIdx += idx[DIM - 1];
return static_cast<expectedType *>(mImpl->rawPtr())[unfoldedIdx];
template <typename expectedType>
expectedType& get(std::size_t idx){
// TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size
return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
}
template <typename expectedType>
expectedType& get(std::vector<std::size_t> coordIdx){
return get<expectedType>(getIdx(coordIdx));
}
template <typename expectedType>
void set(std::size_t idx, expectedType value){
// TODO : add assert expected Type compatible with datatype
// TODO : add assert idx < Size
void* dataPtr = mImpl->getRaw(idx);
std::memcpy(dataPtr, &value, sizeof(expectedType));
}
template <typename expectedType>
void set(std::vector<std::size_t> coordIdx, expectedType value){
set<expectedType>(getIdx(coordIdx), value);
}
std::string toString() {
if (dims().empty()) { return "{}"; }
std::string res;
......@@ -565,10 +580,10 @@ class Tensor : public Data,
* @param flatIdx 1D index of the value considering a flatten tensor.
* @return std::vector<DimSize_t>
*/
std::vector<DimSize_t> getCoord(DimSize_t flatIdx){
std::vector<DimSize_t> coordIdx = {};
DimSize_t idx = flatIdx;
for (DimSize_t d: mDims){
std::vector<std::size_t> getCoord(std::size_t flatIdx){
std::vector<std::size_t> coordIdx = {};
std::size_t idx = flatIdx;
for (std::size_t d: mDims){
coordIdx.push_back(idx % d);
idx/=d;
}
......@@ -581,9 +596,9 @@ class Tensor : public Data,
* @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t
*/
DimSize_t getIdx(std::vector<DimSize_t> coordIdx){
DimSize_t flatIdx = 0;
DimSize_t stride = 1;
std::size_t getIdx(std::vector<std::size_t> coordIdx){
std::size_t flatIdx = 0;
std::size_t stride = 1;
assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions");
for(std::size_t i=0; i< mDims.size(); ++i){
assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor");
......
......@@ -48,7 +48,10 @@ void addCtor(py::class_<Tensor,
}
return newTensor;
}));
}))
.def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set)
.def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set)
;
}
......@@ -84,15 +87,27 @@ void init_Tensor(py::module& m){
return b.size();
})
.def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
// TODO : Should return error if backend not compatible with get
if (idx >= b.size()) throw py::index_error();
switch(b.dataType()){
case DataType::Float64:
return py::cast(static_cast<double*>(b.getImpl()->rawPtr())[idx]);
return py::cast(b.get<double>(idx));
case DataType::Float32:
return py::cast(static_cast<float*>(b.getImpl()->rawPtr())[idx]);
return py::cast(b.get<float>(idx));
case DataType::Int32:
return py::cast(static_cast<int*>(b.getImpl()->rawPtr())[idx]);
return py::cast(b.get<int>(idx));
default:
return py::none();
}
})
.def("__getitem__", [](Tensor& b, std::vector<size_t> coordIdx)-> py::object {
if (b.getIdx(coordIdx) >= b.size()) throw py::index_error();
switch(b.dataType()){
case DataType::Float64:
return py::cast(b.get<double>(coordIdx));
case DataType::Float32:
return py::cast(b.get<float>(coordIdx));
case DataType::Int32:
return py::cast(b.get<int>(coordIdx));
default:
return py::none();
}
......@@ -128,12 +143,12 @@ void init_Tensor(py::module& m){
}
return py::buffer_info(
tensorImpl->rawPtr(), /* Pointer to buffer */
tensorImpl->scalarSize(), /* Size of one scalar */
dataFormatDescriptor, /* Python struct-style format descriptor */
b.nbDims(), /* Number of dimensions */
dims, /* Buffer dimensions */
strides /* Strides (in bytes) for each index */
tensorImpl->rawPtr(), /* Pointer to buffer */
tensorImpl->scalarSize(), /* Size of one scalar */
dataFormatDescriptor, /* Python struct-style format descriptor */
b.nbDims(), /* Number of dimensions */
dims, /* Buffer dimensions */
strides /* Strides (in bytes) for each index */
);
});
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment