diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 688a519e593dcde1fe69e3324c81163250eeb42b..6d6f20ebe9377ce177d936c00f097fee76954bd9 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -31,24 +31,26 @@
 void addCtor(py::class_<Tensor,
              Registrable<Tensor,
                          std::tuple<std::string, DataType>,
                          std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
-    mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
+    mTensor.def(py::init([](
+                py::array_t<T, py::array::c_style | py::array::forcecast> b,
+                std::string backend = "cpu") {
         /* Request a buffer descriptor from Python */
         py::buffer_info info = b.request();
         Tensor* newTensor = new Tensor();
         newTensor->setDataType(NativeType<T>::type);
         const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end());
         newTensor->resize(dims);
-        // TODO : Find a better way to choose backend
+
         std::set<std::string> availableBackends = Tensor::getAvailableBackends();
-        if (availableBackends.find("cpu") != availableBackends.end()){
-            newTensor->setBackend("cpu");
+        if (availableBackends.find(backend) != availableBackends.end()){
+            newTensor->setBackend(backend);
             newTensor->getImpl()->copyFromHost(static_cast<T*>(info.ptr), newTensor->size());
         }else{
-            printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
+            AIDGE_THROW_OR_ABORT(py::value_error, "Could not find backend %s, verify you have `import aidge_backend_%s`.\n", backend.c_str(), backend.c_str());
         }
         return newTensor;
-    }))
+    }), py::arg("array"), py::arg("backend")="cpu")
     .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set)
     .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set)
     ;