Skip to content
Snippets Groups Projects
Commit a92ad137 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add barious methods to python binding.

parent 36121767
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
......@@ -32,6 +32,16 @@ void init_Data(py::module& m){
.value("uint64", DataType::UInt64)
;
py::enum_<DataFormat>(m, "dformat")
.value("Default", DataFormat::Default)
.value("NCHW", DataFormat::NCHW)
.value("NHWC", DataFormat::NHWC)
.value("CHWN", DataFormat::CHWN)
.value("NCDHW", DataFormat::NCDHW)
.value("NDHWC", DataFormat::NDHWC)
.value("CDHWN", DataFormat::CDHWN)
;
py::class_<Data, std::shared_ptr<Data>>(m,"Data");
......
......@@ -84,6 +84,7 @@ void init_Tensor(py::module& m){
.def("grad", &Tensor::grad)
.def("set_grad", &Tensor::setGrad)
.def("dtype", &Tensor::dataType)
.def("dformat", &Tensor::dataFormat)
.def("size", &Tensor::size)
.def("capacity", &Tensor::capacity)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
......
......@@ -30,6 +30,8 @@ void init_GraphView(py::module& m) {
:param path: save location
:type path: str
)mydelimiter")
.def("inputs", (std::vector<std::pair<NodePtr, IOIndex_t>> (GraphView::*)() const) &GraphView::inputs)
.def("outputs", (std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> (GraphView::*)() const) &GraphView::outputs)
.def("in_view", (bool (GraphView::*)(const NodePtr&) const) &GraphView::inView)
.def("in_view", (bool (GraphView::*)(const std::string&) const) &GraphView::inView)
.def("root_node", &GraphView::rootNode)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment