From d68006ece76c74c25dc6a33a3394fe320dd1ec87 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 25 Jan 2024 12:54:54 +0000 Subject: [PATCH] [PyBind] Update GenericOp ctor to use kwargs to set attributes. --- .../operator/pybind_GenericOperator.cpp | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index 68f43bb76..2b146c768 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -27,7 +27,31 @@ void init_GenericOperator(py::module& m) { .def("compute_output_dims", &GenericOperator_Op::computeOutputDims) .def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function")); - m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), - py::arg("name") = ""); + // &GenericOperator + m.def("GenericOperator", + []( const std::string& type, + IOIndex_t nbData, + IOIndex_t nbParam, + IOIndex_t nbOut, + const std::string& name = "", + const py::kwargs kwargs) { + std::shared_ptr<Node> genericNode = GenericOperator( + type, + nbData, + nbParam, + nbOut, + name + ); + if (kwargs){ + std::shared_ptr<GenericOperator_Op> gop = std::static_pointer_cast<GenericOperator_Op>(genericNode->getOperator()); + for (auto item : kwargs) { + std::string key = py::cast<std::string>(item.first); + py::object value = py::reinterpret_borrow<py::object>(item.second); + gop->setAttrPy(key, std::move(value)); + } + } + return genericNode; + } + , py::arg("type"), py::arg("nb_data"), py::arg("nb_param"), py::arg("nb_out"), py::arg("name") = ""); } } // namespace Aidge -- GitLab