Skip to content
Snippets Groups Projects
Commit 988de664 authored by Maxence Naud's avatar Maxence Naud
Browse files

Pybind updates

- add Tensor operators +,-,*,/
- add __repr__ funciton to Tensor
- add operator constructor to each python binding
- fix identity_op Base class declaration in binding
parent cc2f1ab8
No related branches found
No related tags found
No related merge requests found
......@@ -73,6 +73,10 @@ void init_Tensor(py::module& m){
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>())
.def(py::self + py::self)
.def(py::self - py::self)
.def(py::self * py::self)
.def(py::self / py::self)
.def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
......@@ -89,6 +93,9 @@ void init_Tensor(py::module& m){
.def("__str__", [](Tensor& b) {
return b.toString();
})
.def("__repr__", [](Tensor& b) {
return "Tensor(dtype = " + std::string(EnumStrings<DataType>::data[static_cast<int>(b.dataType())]) + ",\n" + b.toString() + ")";
})
.def("__len__", [](Tensor& b) -> size_t{
return b.size();
})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment