From b2f5c6e9f901fd6424736521f78c4db1ad65a92d Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Mon, 29 Jan 2024 10:22:14 +0000
Subject: [PATCH] Add possibility to select backend when creating Tensor.

---
 python_binding/data/pybind_Tensor.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 688a519e5..6d6f20ebe 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -31,24 +31,26 @@ void addCtor(py::class_<Tensor,
                         Registrable<Tensor,
                                     std::tuple<std::string, DataType>,
                                     std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
-    mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) {
+    mTensor.def(py::init([](
+        py::array_t<T, py::array::c_style | py::array::forcecast> b,
+        std::string backend = "cpu") {
         /* Request a buffer descriptor from Python */
         py::buffer_info info = b.request();
         Tensor* newTensor = new Tensor();
         newTensor->setDataType(NativeType<T>::type);
         const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end());
         newTensor->resize(dims);
-        // TODO : Find a better way to choose backend
+
         std::set<std::string> availableBackends = Tensor::getAvailableBackends();
-        if (availableBackends.find("cpu") != availableBackends.end()){
-            newTensor->setBackend("cpu");
+        if (availableBackends.find(backend) != availableBackends.end()){
+            newTensor->setBackend(backend);
             newTensor->getImpl()->copyFromHost(static_cast<T*>(info.ptr), newTensor->size());
         }else{
-            printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
+            AIDGE_THROW_OR_ABORT(py::value_error, "Could not find backend %s, verify you have `import aidge_backend_%s`.\n", backend.c_str(), backend.c_str());
         }
 
         return newTensor;
-    }))
+    }), py::arg("array"), py::arg("backend")="cpu")
     .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set)
     .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set)
     ;
-- 
GitLab