Skip to content
Snippets Groups Projects
Commit 3a88f9be authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed pybind issues

parent 8bc10d7c
No related branches found
No related tags found
No related merge requests found
......@@ -32,7 +32,7 @@ namespace Aidge {
* Contains a pointer to an actual contiguous implementation of data.
*/
class Tensor : public Data,
public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(int device, NbElts_t length)> {
public Registrable<Tensor, std::tuple<std::string, DataType>, std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)> {
private:
DataType mDataType; /** enum to specify data type. */
std::vector<DimSize_t> mDims; /** Dimensions of the tensor. */
......
......@@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){
mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
/* Request a buffer descriptor from Python */
py::buffer_info info = b.request();
......@@ -58,16 +58,16 @@ void addCtor(py::class_<Tensor,
void init_Tensor(py::module& m){
py::class_<Registrable<Tensor,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>,
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>,
std::shared_ptr<Registrable<Tensor,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable");
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable");
py::class_<Tensor, std::shared_ptr<Tensor>,
Data,
Registrable<Tensor,
std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor
std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor
(m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());
pyClassTensor.def(py::init<>())
......@@ -76,7 +76,7 @@ void init_Tensor(py::module& m){
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
.def("dtype", &Tensor::dataType)
.def("size", &Tensor::size)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
.def("has_impl", &Tensor::hasImpl)
.def("get_coord", &Tensor::getCoord)
.def("get_idx", &Tensor::getIdx)
......@@ -114,7 +114,7 @@ void init_Tensor(py::module& m){
}
})
.def_buffer([](Tensor& b) -> py::buffer_info {
const std::unique_ptr<TensorImpl>& tensorImpl = b.getImpl();
const std::shared_ptr<TensorImpl>& tensorImpl = b.getImpl();
std::vector<size_t> dims;
std::vector<size_t> strides;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment