diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 0731498dd3e06541ed82a86a98c2ae0bb355f413..29fada1c5a0eea2a47f4a6921d4ec3dd2d069094 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -28,7 +28,7 @@ enum class ProdAttr { Constant }; class Producer_Op : public OperatorTensor, - public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( + public Registrable<Producer_Op, std::string, std::shared_ptr<OperatorImpl>( const Producer_Op &)>, public StaticAttributes<ProdAttr, bool> { public: @@ -92,9 +92,7 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); } void setBackend(const std::string& name, DeviceIdx_t device = 0) override { - if (Registrar<Producer_Op>::exists({name})) { - mImpl = Registrar<Producer_Op>::create({name})(*this); - } + SET_IMPL_MACRO(Producer_Op, *this, name); mOutputs[0]->setBackend(name, device); } diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index b2a98113f17065a30ddd52b69d897857c01e6854..a5bd260ec189ac998134b738ca1ae757f2a0038c 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -139,8 +139,10 @@ void declare_registrable(py::module& m, const std::string& class_name){ (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ } #else -#define SET_IMPL_MACRO(T_Op, op, backend_name) \ - (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); +#define SET_IMPL_MACRO(T_Op, op, backend_name) \ + if (Registrar<T_Op>::exists(backend_name)) { \ + (op).setImpl(Registrar<T_Op>::create(backend_name)(op)); \ + } #endif } diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 3caa438d18b3919dbedcf66e4ba53b92b84a50b5..025c8c5dd1651b3466a22e88f0966a7f51d2c109 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -26,6 +26,7 @@ void declare_Producer(py::module &m) { // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false); + } @@ -39,7 +40,7 @@ void init_Producer(py::module &m) { .def("get_outputs_name", &Producer_Op::getOutputsName) .def("attributes_name", &Producer_Op::staticGetAttrsName); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false); - + declare_registrable<Producer_Op>(m, "ProducerOp"); declare_Producer<1>(m); declare_Producer<2>(m); declare_Producer<3>(m);