From a92ad1379193a1b82247b9a0e3d7649d2544eec2 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 9 Jul 2024 09:24:43 +0000 Subject: [PATCH] Add barious methods to python binding. --- python_binding/data/pybind_Data.cpp | 10 ++++++++++ python_binding/data/pybind_Tensor.cpp | 1 + python_binding/graph/pybind_GraphView.cpp | 2 ++ 3 files changed, 13 insertions(+) diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index c6595360b..d750c50d9 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -32,6 +32,16 @@ void init_Data(py::module& m){ .value("uint64", DataType::UInt64) ; + py::enum_<DataFormat>(m, "dformat") + .value("Default", DataFormat::Default) + .value("NCHW", DataFormat::NCHW) + .value("NHWC", DataFormat::NHWC) + .value("CHWN", DataFormat::CHWN) + .value("NCDHW", DataFormat::NCDHW) + .value("NDHWC", DataFormat::NDHWC) + .value("CDHWN", DataFormat::CDHWN) + ; + py::class_<Data, std::shared_ptr<Data>>(m,"Data"); diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 83bb4afea..fb33aaaa0 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -84,6 +84,7 @@ void init_Tensor(py::module& m){ .def("grad", &Tensor::grad) .def("set_grad", &Tensor::setGrad) .def("dtype", &Tensor::dataType) + .def("dformat", &Tensor::dataFormat) .def("size", &Tensor::size) .def("capacity", &Tensor::capacity) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 293038381..298f3f54a 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -30,6 +30,8 @@ void init_GraphView(py::module& m) { :param path: save location :type path: str )mydelimiter") + .def("inputs", (std::vector<std::pair<NodePtr, IOIndex_t>> (GraphView::*)() const) &GraphView::inputs) + .def("outputs", (std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> (GraphView::*)() const) &GraphView::outputs) .def("in_view", (bool (GraphView::*)(const NodePtr&) const) &GraphView::inView) .def("in_view", (bool (GraphView::*)(const std::string&) const) &GraphView::inView) .def("root_node", &GraphView::rootNode) -- GitLab