From 3934a754664973c765883e06aa848923f8cef8b3 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Wed, 13 Nov 2024 09:25:22 +0000 Subject: [PATCH] [Upd] Make GenericOperator registrable. --- include/aidge/operator/GenericOperator.hpp | 2 +- python_binding/operator/pybind_GenericOperator.cpp | 2 ++ src/operator/GenericOperator.cpp | 12 ++++++++++-- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 89b2c06a5..327f4f7c3 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -26,7 +26,7 @@ namespace Aidge { class GenericOperator_Op : public OperatorTensor, - public Registrable<GenericOperator_Op, std::string, std::function<std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)>> { + public Registrable<GenericOperator_Op, std::array<std::string, 2>, std::function<std::shared_ptr<OperatorImpl>(const GenericOperator_Op &)>> { private: using ComputeDimsFunc = std::function<std::vector<std::vector<size_t>>(const std::vector<std::vector<size_t>>&)>; diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index 6af8fef88..f125291fa 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -64,5 +64,7 @@ void init_GenericOperator(py::module& m) { } return genericNode; }, py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), py::arg("name") = ""); + + declare_registrable<GenericOperator_Op>(m, "GenericOperatorOp"); } } // namespace Aidge diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp index 0f90a5a58..c5bca9240 100644 --- a/src/operator/GenericOperator.cpp +++ b/src/operator/GenericOperator.cpp @@ -86,7 +86,15 @@ bool Aidge::GenericOperator_Op::forwardDims(bool /*allowDataDependency*/) { } void Aidge::GenericOperator_Op::setBackend(const std::string & name, DeviceIdx_t device) { - Log::warn("GenericOperator::setBackend(): cannot set backend for a generic operator, as no implementation has been provided!"); + if (Registrar<GenericOperator_Op>::exists({name, type()})) { + // A custom implementation exists for this meta operator + mImpl = Registrar<GenericOperator_Op>::create({name, type()})(*this); + }else{ + Log::warn("GenericOperator::setBackend(): cannot set backend for a generic operator, as no implementation has been provided!"); + } + + + for (std::size_t i = 0; i < nbOutputs(); ++i) { mOutputs[i]->setBackend(name, device); @@ -108,4 +116,4 @@ std::shared_ptr<Aidge::Node> Aidge::GenericOperator(const std::string& type, Aidge::IOIndex_t nbOut, const std::string& name) { return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name); -} \ No newline at end of file +} -- GitLab