diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index 68f43bb76571bceb501966abf74e47303533bb15..2b146c768003fa825b2cf6cacf88c6767c567a66 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -27,7 +27,31 @@ void init_GenericOperator(py::module& m) { .def("compute_output_dims", &GenericOperator_Op::computeOutputDims) .def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function")); - m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), - py::arg("name") = ""); + // &GenericOperator + m.def("GenericOperator", + []( const std::string& type, + IOIndex_t nbData, + IOIndex_t nbParam, + IOIndex_t nbOut, + const std::string& name = "", + const py::kwargs kwargs) { + std::shared_ptr<Node> genericNode = GenericOperator( + type, + nbData, + nbParam, + nbOut, + name + ); + if (kwargs){ + std::shared_ptr<GenericOperator_Op> gop = std::static_pointer_cast<GenericOperator_Op>(genericNode->getOperator()); + for (auto item : kwargs) { + std::string key = py::cast<std::string>(item.first); + py::object value = py::reinterpret_borrow<py::object>(item.second); + gop->setAttrPy(key, std::move(value)); + } + } + return genericNode; + } + , py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), py::arg("name") = ""); } } // namespace Aidge