From a92ad1379193a1b82247b9a0e3d7649d2544eec2 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 9 Jul 2024 09:24:43 +0000
Subject: [PATCH] Add barious methods to python binding.

---
 python_binding/data/pybind_Data.cpp       | 10 ++++++++++
 python_binding/data/pybind_Tensor.cpp     |  1 +
 python_binding/graph/pybind_GraphView.cpp |  2 ++
 3 files changed, 13 insertions(+)

diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp
index c6595360b..d750c50d9 100644
--- a/python_binding/data/pybind_Data.cpp
+++ b/python_binding/data/pybind_Data.cpp
@@ -32,6 +32,16 @@ void init_Data(py::module& m){
     .value("uint64", DataType::UInt64)
     ;
 
+    py::enum_<DataFormat>(m, "dformat")
+    .value("Default", DataFormat::Default)
+    .value("NCHW", DataFormat::NCHW)
+    .value("NHWC", DataFormat::NHWC)
+    .value("CHWN", DataFormat::CHWN)
+    .value("NCDHW", DataFormat::NCDHW)
+    .value("NDHWC", DataFormat::NDHWC)
+    .value("CDHWN", DataFormat::CDHWN)
+    ;
+
     py::class_<Data, std::shared_ptr<Data>>(m,"Data");
 
 
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 83bb4afea..fb33aaaa0 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -84,6 +84,7 @@ void init_Tensor(py::module& m){
     .def("grad", &Tensor::grad)
     .def("set_grad", &Tensor::setGrad)
     .def("dtype", &Tensor::dataType)
+    .def("dformat", &Tensor::dataFormat)
     .def("size", &Tensor::size)
     .def("capacity", &Tensor::capacity)
     .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index 293038381..298f3f54a 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -30,6 +30,8 @@ void init_GraphView(py::module& m) {
           :param path: save location
           :type path: str
           )mydelimiter")
+          .def("inputs", (std::vector<std::pair<NodePtr, IOIndex_t>> (GraphView::*)() const) &GraphView::inputs)
+          .def("outputs", (std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> (GraphView::*)() const) &GraphView::outputs)
           .def("in_view", (bool (GraphView::*)(const NodePtr&) const) &GraphView::inView)
           .def("in_view", (bool (GraphView::*)(const std::string&) const) &GraphView::inView)
           .def("root_node", &GraphView::rootNode)
-- 
GitLab