diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index c948b1ffd414fd1b421c9a842a16982501b5b2e0..b09570792b0737376c7d477fa7addd477a212bd8 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor, Data, Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){ + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>& mTensor){ mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) { /* Request a buffer descriptor from Python */ py::buffer_info info = b.request(); @@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor, void init_Tensor(py::module& m){ py::class_<Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>, + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>, std::shared_ptr<Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable"); + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>>(m,"TensorRegistrable"); py::class_<Tensor, std::shared_ptr<Tensor>, Data, Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>> pyClassTensor (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); pyClassTensor.def(py::init<>())