From 693ee69f34c55aff681f127b879276cac820a97e Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Mon, 16 Sep 2024 18:01:45 +0200 Subject: [PATCH] Allow to set a list of backend by order of preference --- include/aidge/operator/Operator.hpp | 1 + python_binding/operator/pybind_Operator.cpp | 3 ++- src/operator/Operator.cpp | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 05cd6e8ae..93e9664e2 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -129,6 +129,7 @@ public: } virtual void setBackend(const std::string& name, DeviceIdx_t device = 0) = 0; + void setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends); virtual void setDataType(const DataType& dataType) const = 0; virtual void setDataFormat(const DataFormat& dataFormat) const = 0; diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index dbf71a3ca..81a62f4ed 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -53,7 +53,8 @@ void init_Operator(py::module& m){ )mydelimiter") .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDataType, py::arg("dataType")) - .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0) + .def("set_backend", py::overload_cast<const std::string&, DeviceIdx_t>(&Operator::setBackend), py::arg("name"), py::arg("device") = 0) + .def("set_backend", py::overload_cast<const std::vector<std::pair<std::string, DeviceIdx_t>>&>(&Operator::setBackend), py::arg("backends")) .def("forward", &Operator::forward) // py::keep_alive forbide Python to garbage collect the implementation lambda as long as the Operator is not deleted ! .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index 762d5fda8..f15a7dc38 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -80,3 +80,17 @@ void Aidge::Operator::backward() { AIDGE_ASSERT(mImpl != nullptr, "backward(): an implementation is required for {}!", type()); mImpl->backward(); } + +void Aidge::Operator::setBackend(const std::vector<std::pair<std::string, DeviceIdx_t>>& backends) { + const auto& availableBackends = getAvailableBackends(); + // By default, try to set the last backend anyway + auto selectedBackend = backends.back(); + for (const auto& backend : backends) { + if (availableBackends.find(backend.first) != availableBackends.end()) { + selectedBackend = backend; + break; + } + } + + setBackend(selectedBackend.first, selectedBackend.second); +} -- GitLab