From 0182775fd06a414b04892fec8e2a3c7479bb2382 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 14 Feb 2025 13:00:44 +0000 Subject: [PATCH] [Fix] Make Squeeze registrable + fix python doc. --- python_binding/operator/pybind_Squeeze.cpp | 44 +++++++++++----------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/python_binding/operator/pybind_Squeeze.cpp b/python_binding/operator/pybind_Squeeze.cpp index ca90fb46a..188ce745d 100644 --- a/python_binding/operator/pybind_Squeeze.cpp +++ b/python_binding/operator/pybind_Squeeze.cpp @@ -24,29 +24,29 @@ namespace Aidge { void init_Squeeze(py::module &m) { py::class_<Squeeze_Op, std::shared_ptr<Squeeze_Op>, OperatorTensor>( - m, "SqueezeOp", py::multiple_inheritance(), - R"mydelimiter( - Initialize squeeze operator - :param axes : axes to squeeze between [-r;r-1] - with r = input_tensor.nbDims() - & r in [-128 , 127] - :type axes : :py:class: List[Int] - )mydelimiter") - .def("get_inputs_name", &Squeeze_Op::getInputsName) - .def("get_outputs_name", &Squeeze_Op::getOutputsName) - .def("axes", &Squeeze_Op::axes); - // Here we bind the constructor of the Squeeze Node. We add an argument - // for each attribute of the operator (in here we only have 'axes') and - // the last argument is the node's name. - m.def("Squeeze", &Squeeze, py::arg("axes") = std::vector<int8_t>({}), + m, "SqueezeOp", py::multiple_inheritance(), + R"mydelimiter( + Initialize squeeze operator + :param axes: axes to squeeze between [-r;r-1] + with r = input_tensor.nbDims() + & r in [-128 , 127] + :type axes: :py:class: List[Int] + )mydelimiter") + .def("get_inputs_name", &Squeeze_Op::getInputsName) + .def("get_outputs_name", &Squeeze_Op::getOutputsName) + .def("axes", &Squeeze_Op::axes); + + declare_registrable<Squeeze_Op>(m, "SqueezeOp"); + m.def("Squeeze", &Squeeze, py::arg("axes") = std::vector<int8_t>({}), py::arg("name") = "", R"mydelimiter( - Initialize a node containing a squeeze operator. - :param axes : axes to squeeze between [-r;r-1] - with r = input_tensor.nbDims() - & r in [-128 , 127] - :type axes : :py:class: List[Int] - :param name : name of the node. -)mydelimiter"); + Initialize a node containing a squeeze operator. + :param axes: axes to squeeze between [-r;r-1] + with r = input_tensor.nbDims() + & r in [-128 , 127] + :type axes: :py:class: List[Int] + :param name: name of the node. + :type name: str + )mydelimiter"); } } // namespace Aidge -- GitLab