From b2f5c6e9f901fd6424736521f78c4db1ad65a92d Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Mon, 29 Jan 2024 10:22:14 +0000 Subject: [PATCH] Add possibility to select backend when creating Tensor. --- python_binding/data/pybind_Tensor.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index 688a519e5..6d6f20ebe 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -31,24 +31,26 @@ void addCtor(py::class_<Tensor, Registrable<Tensor, std::tuple<std::string, DataType>, std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){ - mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) { + mTensor.def(py::init([]( + py::array_t<T, py::array::c_style | py::array::forcecast> b, + std::string backend = "cpu") { /* Request a buffer descriptor from Python */ py::buffer_info info = b.request(); Tensor* newTensor = new Tensor(); newTensor->setDataType(NativeType<T>::type); const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end()); newTensor->resize(dims); - // TODO : Find a better way to choose backend + std::set<std::string> availableBackends = Tensor::getAvailableBackends(); - if (availableBackends.find("cpu") != availableBackends.end()){ - newTensor->setBackend("cpu"); + if (availableBackends.find(backend) != availableBackends.end()){ + newTensor->setBackend(backend); newTensor->getImpl()->copyFromHost(static_cast<T*>(info.ptr), newTensor->size()); }else{ - printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); + AIDGE_THROW_OR_ABORT(py::value_error, "Could not find backend %s, verify you have `import aidge_backend_%s`.\n", backend.c_str(), backend.c_str()); } return newTensor; - })) + }), py::arg("array"), py::arg("backend")="cpu") .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set) .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set) ; -- GitLab