From 9b481505e3b28cde900f1086b7e759f64be9ce37 Mon Sep 17 00:00:00 2001
From: idealbuq <iryna.dealbuquerquesilva@cea.fr>
Date: Thu, 12 Sep 2024 09:51:16 +0000
Subject: [PATCH] Added pybindings for dataformat and get_ranked_nodes

---
 python_binding/data/pybind_Data.cpp       | 24 +++++++++++++++++++++++
 python_binding/data/pybind_Tensor.cpp     |  1 +
 python_binding/graph/pybind_GraphView.cpp |  3 +++
 3 files changed, 28 insertions(+)

diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp
index 1d4eae077..4e17328e5 100644
--- a/python_binding/data/pybind_Data.cpp
+++ b/python_binding/data/pybind_Data.cpp
@@ -53,6 +53,30 @@ void init_Data(py::module& m){
     e_dtype.def("__str__", [enum_names](const DataType& dtype) {
         return enum_names[static_cast<int>(dtype)];
     }, py::prepend());;
+    // TODO : extend with more values !
+    py::enum_<DataType>(m, "dtype")
+    .value("float64", DataType::Float64)
+    .value("float32", DataType::Float32)
+    .value("float16", DataType::Float16)
+    .value("int8", DataType::Int8)
+    .value("int16", DataType::Int16)
+    .value("int32", DataType::Int32)
+    .value("int64", DataType::Int64)
+    .value("uint8", DataType::UInt8)
+    .value("uint16", DataType::UInt16)
+    .value("uint32", DataType::UInt32)
+    .value("uint64", DataType::UInt64)
+    ;
+
+    py::enum_<DataFormat>(m, "dformat")
+    .value("Default", DataFormat::Default)
+    .value("NCHW", DataFormat::NCHW) // default
+    .value("NHWC", DataFormat::NHWC) 
+    .value("CHWN", DataFormat::CHWN)
+    .value("NCDHW", DataFormat::NCDHW) 
+    .value("NDHWC", DataFormat::NDHWC)
+    .value("CDHWN", DataFormat::CDHWN)
+    ;
 
     py::class_<Data, std::shared_ptr<Data>>(m,"Data");
 
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 73a55d461..cdbcf3dcc 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -322,6 +322,7 @@ void init_Tensor(py::module& m){
     .def("grad", &Tensor::grad)
     .def("set_grad", &Tensor::setGrad)
     .def("dtype", &Tensor::dataType)
+    .def("dformat", &Tensor::dataFormat)
     .def("size", &Tensor::size)
     .def("capacity", &Tensor::capacity)
     .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize, py::arg("dims"), py::arg("strides") = std::vector<DimSize_t>())
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index 293038381..129b9d061 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -142,6 +142,9 @@ void init_GraphView(py::module& m) {
           //                return py::none();
           //           }
           //      })
+          .def("get_ranked_nodes", &GraphView::getRankedNodes)
+          .def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat"))
+          
             ;
 
      m.def("get_connected_graph_view", &getConnectedGraphView);
-- 
GitLab