From 7fc2dd9535f2addd2f162f2afa6774a83aa51cd0 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Mon, 12 Feb 2024 15:52:32 +0000 Subject: [PATCH] Update pybind_trensor with the ctr by dims instead of number of elements --- python_binding/data/pybind_Tensor.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index c948b1ffd..b09570792 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor, Data, Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){ + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>& mTensor){ mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) { /* Request a buffer descriptor from Python */ py::buffer_info info = b.request(); @@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor, void init_Tensor(py::module& m){ py::class_<Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>, + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>, std::shared_ptr<Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable"); + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>>(m,"TensorRegistrable"); py::class_<Tensor, std::shared_ptr<Tensor>, Data, Registrable<Tensor, std::tuple<std::string, DataType>, - std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor + std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>> pyClassTensor (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); pyClassTensor.def(py::init<>()) -- GitLab