From 7fc2dd9535f2addd2f162f2afa6774a83aa51cd0 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Mon, 12 Feb 2024 15:52:32 +0000
Subject: [PATCH] Update pybind_trensor with the ctr by dims instead of number
 of elements

---
 python_binding/data/pybind_Tensor.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index c948b1ffd..b09570792 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor,
                         Data,
                         Registrable<Tensor,
                                     std::tuple<std::string, DataType>,
-                                    std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){
+                                    std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>& mTensor){
     mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
         /* Request a buffer descriptor from Python */
         py::buffer_info info = b.request();
@@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor,
 void init_Tensor(py::module& m){
     py::class_<Registrable<Tensor,
                            std::tuple<std::string, DataType>,
-                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>,
+                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>,
                std::shared_ptr<Registrable<Tensor,
                                            std::tuple<std::string, DataType>,
-                                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable");
+                                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>>>(m,"TensorRegistrable");
 
     py::class_<Tensor, std::shared_ptr<Tensor>,
                Data,
                Registrable<Tensor,
                            std::tuple<std::string, DataType>,
-                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor
+                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, std::vector<DimSize_t> dims)>> pyClassTensor
         (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
 
     pyClassTensor.def(py::init<>())
-- 
GitLab