diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index c6595360b17ee08eaa82d483987914adc67b60a8..d750c50d984dc0def5a0639ad731f41b83714823 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -32,6 +32,16 @@ void init_Data(py::module& m){ .value("uint64", DataType::UInt64) ; + py::enum_<DataFormat>(m, "dformat") + .value("Default", DataFormat::Default) + .value("NCHW", DataFormat::NCHW) + .value("NHWC", DataFormat::NHWC) + .value("CHWN", DataFormat::CHWN) + .value("NCDHW", DataFormat::NCDHW) + .value("NDHWC", DataFormat::NDHWC) + .value("CDHWN", DataFormat::CDHWN) + ; + py::class_<Data, std::shared_ptr<Data>>(m,"Data"); diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 83bb4afeacdd6de181fd6738edad2229736854c8..fb33aaaa00a54908dcb5cd8b6d13e235dfd0f5b9 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -84,6 +84,7 @@ void init_Tensor(py::module& m){ .def("grad", &Tensor::grad) .def("set_grad", &Tensor::setGrad) .def("dtype", &Tensor::dataType) + .def("dformat", &Tensor::dataFormat) .def("size", &Tensor::size) .def("capacity", &Tensor::capacity) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 2930383817d1555d51b8bddd8eff6402240e905a..298f3f54af02b68370e0de611353e3b3bc3b37d6 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -30,6 +30,8 @@ void init_GraphView(py::module& m) { :param path: save location :type path: str )mydelimiter") + .def("inputs", (std::vector<std::pair<NodePtr, IOIndex_t>> (GraphView::*)() const) &GraphView::inputs) + .def("outputs", (std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> (GraphView::*)() const) &GraphView::outputs) .def("in_view", (bool (GraphView::*)(const NodePtr&) const) &GraphView::inView) .def("in_view", (bool (GraphView::*)(const std::string&) const) &GraphView::inView) .def("root_node", &GraphView::rootNode)