From 3a88f9be397b6b40a3d042154ac034985f1ace23 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Mon, 15 Jan 2024 16:07:40 +0100
Subject: [PATCH] Fixed pybind issues

---
 include/aidge/data/Tensor.hpp         |  2 +-
 python_binding/data/pybind_Tensor.cpp | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index a6d0ce341..36947ca7f 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -32,7 +32,7 @@ namespace Aidge {
  * Contains a pointer to an actual contiguous implementation of data.
  */
 class Tensor : public Data,
-               public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(int device, NbElts_t length)> {
+               public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)> {
    private:
     DataType mDataType; /** enum to specify data type. */
     std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index fa109a9af..4f760a65f 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor,
                         Data,
                         Registrable<Tensor,
                                     std::tuple<std::string, DataType>,
-                                    std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
+                                    std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){
     mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
         /* Request a buffer descriptor from Python */
         py::buffer_info info = b.request();
@@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor,
 void init_Tensor(py::module& m){
     py::class_<Registrable<Tensor,
                            std::tuple<std::string, DataType>,
-                           std::unique_ptr<TensorImpl>(const Tensor&)>,
+                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>,
                std::shared_ptr<Registrable<Tensor,
                                            std::tuple<std::string, DataType>,
-                                           std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable");
+                                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable");
 
     py::class_<Tensor, std::shared_ptr<Tensor>,
                Data,
                Registrable<Tensor,
                            std::tuple<std::string, DataType>,
-                           std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor
+                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor
         (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
 
     pyClassTensor.def(py::init<>())
@@ -76,7 +76,7 @@ void init_Tensor(py::module& m){
     .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
     .def("dtype", &Tensor::dataType)
     .def("size", &Tensor::size)
-    .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize)
+    .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
     .def("has_impl", &Tensor::hasImpl)
     .def("get_coord", &Tensor::getCoord)
     .def("get_idx", &Tensor::getIdx)
@@ -114,7 +114,7 @@ void init_Tensor(py::module& m){
         }
     })
     .def_buffer([](Tensor& b) -> py::buffer_info {
-        const std::unique_ptr<TensorImpl>& tensorImpl = b.getImpl();
+        const std::shared_ptr<TensorImpl>& tensorImpl = b.getImpl();
 
         std::vector<size_t> dims;
         std::vector<size_t> strides;
-- 
GitLab