From 9b481505e3b28cde900f1086b7e759f64be9ce37 Mon Sep 17 00:00:00 2001 From: idealbuq <iryna.dealbuquerquesilva@cea.fr> Date: Thu, 12 Sep 2024 09:51:16 +0000 Subject: [PATCH] Added pybindings for dataformat and get_ranked_nodes --- python_binding/data/pybind_Data.cpp | 24 +++++++++++++++++++++++ python_binding/data/pybind_Tensor.cpp | 1 + python_binding/graph/pybind_GraphView.cpp | 3 +++ 3 files changed, 28 insertions(+) diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index 1d4eae077..4e17328e5 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -53,6 +53,30 @@ void init_Data(py::module& m){ e_dtype.def("__str__", [enum_names](const DataType& dtype) { return enum_names[static_cast<int>(dtype)]; }, py::prepend());; + // TODO : extend with more values ! + py::enum_<DataType>(m, "dtype") + .value("float64", DataType::Float64) + .value("float32", DataType::Float32) + .value("float16", DataType::Float16) + .value("int8", DataType::Int8) + .value("int16", DataType::Int16) + .value("int32", DataType::Int32) + .value("int64", DataType::Int64) + .value("uint8", DataType::UInt8) + .value("uint16", DataType::UInt16) + .value("uint32", DataType::UInt32) + .value("uint64", DataType::UInt64) + ; + + py::enum_<DataFormat>(m, "dformat") + .value("Default", DataFormat::Default) + .value("NCHW", DataFormat::NCHW) // default + .value("NHWC", DataFormat::NHWC) + .value("CHWN", DataFormat::CHWN) + .value("NCDHW", DataFormat::NCDHW) + .value("NDHWC", DataFormat::NDHWC) + .value("CDHWN", DataFormat::CDHWN) + ; py::class_<Data, std::shared_ptr<Data>>(m,"Data"); diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 73a55d461..cdbcf3dcc 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -322,6 +322,7 @@ void init_Tensor(py::module& m){ .def("grad", &Tensor::grad) .def("set_grad", &Tensor::setGrad) .def("dtype", &Tensor::dataType) + .def("dformat", &Tensor::dataFormat) .def("size", &Tensor::size) .def("capacity", &Tensor::capacity) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize, py::arg("dims"), py::arg("strides") = std::vector<DimSize_t>()) diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 293038381..129b9d061 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -142,6 +142,9 @@ void init_GraphView(py::module& m) { // return py::none(); // } // }) + .def("get_ranked_nodes", &GraphView::getRankedNodes) + .def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat")) + ; m.def("get_connected_graph_view", &getConnectedGraphView); -- GitLab