Skip to content
Snippets Groups Projects
Commit b2f5c6e9 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add possibility to select backend when creating Tensor.

parent 1c7872c6
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!73Quality of life
Pipeline #37744 failed
...@@ -31,24 +31,26 @@ void addCtor(py::class_<Tensor, ...@@ -31,24 +31,26 @@ void addCtor(py::class_<Tensor,
Registrable<Tensor, Registrable<Tensor,
std::tuple<std::string, DataType>, std::tuple<std::string, DataType>,
std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){ std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
mTensor.def(py::init([]( py::array_t<T, py::array::c_style | py::array::forcecast> b) { mTensor.def(py::init([](
py::array_t<T, py::array::c_style | py::array::forcecast> b,
std::string backend = "cpu") {
/* Request a buffer descriptor from Python */ /* Request a buffer descriptor from Python */
py::buffer_info info = b.request(); py::buffer_info info = b.request();
Tensor* newTensor = new Tensor(); Tensor* newTensor = new Tensor();
newTensor->setDataType(NativeType<T>::type); newTensor->setDataType(NativeType<T>::type);
const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end()); const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end());
newTensor->resize(dims); newTensor->resize(dims);
// TODO : Find a better way to choose backend
std::set<std::string> availableBackends = Tensor::getAvailableBackends(); std::set<std::string> availableBackends = Tensor::getAvailableBackends();
if (availableBackends.find("cpu") != availableBackends.end()){ if (availableBackends.find(backend) != availableBackends.end()){
newTensor->setBackend("cpu"); newTensor->setBackend(backend);
newTensor->getImpl()->copyFromHost(static_cast<T*>(info.ptr), newTensor->size()); newTensor->getImpl()->copyFromHost(static_cast<T*>(info.ptr), newTensor->size());
}else{ }else{
printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); AIDGE_THROW_OR_ABORT(py::value_error, "Could not find backend %s, verify you have `import aidge_backend_%s`.\n", backend.c_str(), backend.c_str());
} }
return newTensor; return newTensor;
})) }), py::arg("array"), py::arg("backend")="cpu")
.def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set) .def("__setitem__", (void (Tensor::*)(std::size_t, T)) &Tensor::set)
.def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set) .def("__setitem__", (void (Tensor::*)(std::vector<std::size_t>, T)) &Tensor::set)
; ;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment