Skip to content
Snippets Groups Projects
Commit edc2e9f0 authored by Maxence Naud's avatar Maxence Naud
Browse files

Add Tensor python constructors from shape and single value

parent 5a62267c
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!157Chore_26_compare-layer-import-export-onnx
Pipeline #50318 passed
...@@ -51,6 +51,7 @@ void addCtor(py::class_<Tensor, ...@@ -51,6 +51,7 @@ void addCtor(py::class_<Tensor,
return newTensor; return newTensor;
}), py::arg("array"), py::arg("backend")="cpu") }), py::arg("array"), py::arg("backend")="cpu")
.def(py::init<T>(), py::arg("val"))
.def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set) .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set)
.def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set) .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set)
; ;
...@@ -73,6 +74,7 @@ void init_Tensor(py::module& m){ ...@@ -73,6 +74,7 @@ void init_Tensor(py::module& m){
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>()) pyClassTensor.def(py::init<>())
.def(py::init<const std::vector<std::size_t>&>(), py::arg("dims"))
.def(py::self + py::self) .def(py::self + py::self)
.def(py::self - py::self) .def(py::self - py::self)
.def(py::self * py::self) .def(py::self * py::self)
...@@ -86,7 +88,7 @@ void init_Tensor(py::module& m){ ...@@ -86,7 +88,7 @@ void init_Tensor(py::module& m){
.def("dtype", &Tensor::dataType) .def("dtype", &Tensor::dataType)
.def("size", &Tensor::size) .def("size", &Tensor::size)
.def("capacity", &Tensor::capacity) .def("capacity", &Tensor::capacity)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize, py::arg("dims"), py::arg("strides") = std::vector<DimSize_t>())
.def("has_impl", &Tensor::hasImpl) .def("has_impl", &Tensor::hasImpl)
.def("get_coord", &Tensor::getCoord) .def("get_coord", &Tensor::getCoord)
.def("get_idx", &Tensor::getIdx) .def("get_idx", &Tensor::getIdx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment