diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index fe606cfb557042d581e09da7419d80841d1dc2d4..fc800816b7ab012cd87d6af8d451392b39563e4f 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -318,6 +318,7 @@ void init_Tensor(py::module& m){ .def("clone", &Tensor::clone) .def("sqrt", &Tensor::sqrt) .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true) + .def("set_data_format", &Tensor::setDataFormat, py::arg("data_format"), py::arg("copyTrans") = true) .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("grad", &Tensor::grad)