Skip to content
Snippets Groups Projects
Commit 7fc2dd95 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Update pybind_trensor with the ctr by dims instead of number of elements

parent 8cfcd350
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){
std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>& mTensor){
mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
/* Request a buffer descriptor from Python */
py::buffer_info info = b.request();
......@@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor,
void init_Tensor(py::module& m){
py::class_<Registrable<Tensor,
std::tuple<std::string, DataType>,
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>,
std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>,
std::shared_ptr<Registrable<Tensor,
std::tuple<std::string, DataType>,
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable");
std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>>(m,"TensorRegistrable");
py::class_<Tensor, std::shared_ptr<Tensor>,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor
std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>> pyClassTensor
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment