From 3a88f9be397b6b40a3d042154ac034985f1ace23 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 15 Jan 2024 16:07:40 +0100 Subject: [PATCH] Fixed pybind issues --- include/aidge/data/Tensor.hpp | 2 +- python_binding/data/pybind_Tensor.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index a6d0ce341..36947ca7f 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -32,7 +32,7 @@ namespace Aidge { * Contains a pointer to an actual contiguous implementation of data. */ class Tensor : public Data, - public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(int device, NbElts_t length)> { + public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)> { private: DataType mDataType; /** enum to specify data type. */ std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */ diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index fa109a9af..4f760a65f 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor, Data, Registrable<Tensor, std::tuple<std::string, DataType>, - std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){ + std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){ mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) { /* Request a buffer descriptor from Python */ py::buffer_info info = b.request(); @@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor, void init_Tensor(py::module& m){ py::class_<Registrable<Tensor, std::tuple<std::string, DataType>, - std::unique_ptr<TensorImpl>(const Tensor&)>, + std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>, std::shared_ptr<Registrable<Tensor, std::tuple<std::string, DataType>, - std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable"); + std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable"); py::class_<Tensor, std::shared_ptr<Tensor>, Data, Registrable<Tensor, std::tuple<std::string, DataType>, - std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor + std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol()); pyClassTensor.def(py::init<>()) @@ -76,7 +76,7 @@ void init_Tensor(py::module& m){ .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("dtype", &Tensor::dataType) .def("size", &Tensor::size) - .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize) + .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) .def("has_impl", &Tensor::hasImpl) .def("get_coord", &Tensor::getCoord) .def("get_idx", &Tensor::getIdx) @@ -114,7 +114,7 @@ void init_Tensor(py::module& m){ } }) .def_buffer([](Tensor& b) -> py::buffer_info { - const std::unique_ptr<TensorImpl>& tensorImpl = b.getImpl(); + const std::shared_ptr<TensorImpl>& tensorImpl = b.getImpl(); std::vector<size_t> dims; std::vector<size_t> strides; -- GitLab