From 4a1682e7c39e669427c64dfac125f34172620f43 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Mon, 3 Jun 2024 08:41:41 +0000
Subject: [PATCH] Pybind updates

- add Tensor operators +,-,*,/
- add __repr__ funciton to Tensor
- add operator constructor to each python binding
- fix identity_op Base class declaration in binding
---
 python_binding/data/pybind_Tensor.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 3c2120565..005175ab6 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -73,6 +73,10 @@ void init_Tensor(py::module& m){
         (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
 
     pyClassTensor.def(py::init<>())
+    .def(py::self + py::self)
+    .def(py::self - py::self)
+    .def(py::self * py::self)
+    .def(py::self / py::self)
     .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
     .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
     .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
@@ -89,6 +93,9 @@ void init_Tensor(py::module& m){
     .def("__str__", [](Tensor& b) {
         return b.toString();
     })
+    .def("__repr__", [](Tensor& b) {
+        return "Tensor(dtype = " + std::string(EnumStrings<DataType>::data[static_cast<int>(b.dataType())]) + ",\n" + b.toString() + ")";
+    })
     .def("__len__", [](Tensor& b) -> size_t{
         return b.size();
     })
-- 
GitLab