diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index fe9b044e2309eb7e724d6648b84c044d7407bafb..e31d8c7d281c1c7ee5be77b4e76c797a32d53917 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -28,7 +28,7 @@ enum class ProdAttr { Constant }; class Producer_Op : public OperatorTensor, - public Registrable<Producer_Op, std::string, std::unique_ptr<OperatorImpl>( + public Registrable<Producer_Op, std::string, std::shared_ptr<OperatorImpl>( const Producer_Op &)>, public StaticAttributes<ProdAttr, bool> { public: @@ -88,7 +88,7 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutputs[0]->dims(); } void setBackend(const std::string& name, DeviceIdx_t device = 0) override { - mImpl = Registrar<Producer_Op>::create(name)(*this); + SET_IMPL_MACRO(Producer_Op, *this, name); mOutputs[0]->setBackend(name, device); } diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 3caa438d18b3919dbedcf66e4ba53b92b84a50b5..025c8c5dd1651b3466a22e88f0966a7f51d2c109 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -26,6 +26,7 @@ void declare_Producer(py::module &m) { // m.def(("Producer_" + std::to_string(DIM)+"D").c_str(), py::overload_cast<shared_ptr<Node>&>(&Producer<DIM>), py::arg("dims"), py::arg("name")); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::array<DimSize_t, DIM>&, const std::string&, bool)>(&Producer), py::arg("dims"), py::arg("name") = "", py::arg("constant") = false); + } @@ -39,7 +40,7 @@ void init_Producer(py::module &m) { .def("get_outputs_name", &Producer_Op::getOutputsName) .def("attributes_name", &Producer_Op::staticGetAttrsName); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&, bool)>(&Producer), py::arg("tensor"), py::arg("name") = "", py::arg("constant") = false); - + declare_registrable<Producer_Op>(m, "ProducerOp"); declare_Producer<1>(m); declare_Producer<2>(m); declare_Producer<3>(m);