From 897f3cb8e868c867aad9f7e3d3b3c561cedc74f7 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 14 Feb 2025 12:53:17 +0000 Subject: [PATCH] [Fix] Add default arg axis=0 for concat --- include/aidge/operator/Concat.hpp | 4 ++-- python_binding/operator/pybind_Concat.cpp | 24 +++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 83914b673..ad31ef1a3 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -56,7 +56,7 @@ enum class ConcatAttr { * * The specified axis determines the direction of concatenating. */ - Axis + Axis }; /** @@ -107,7 +107,7 @@ public: * @param[in] nbIn Number of input tensors. * @param[in] axis Axis along which concatenation is performed. */ - Concat_Op(const IOIndex_t nbIn, const std::int32_t axis); + Concat_Op(const IOIndex_t nbIn, const std::int32_t axis = 0); /** * @brief Copy-constructor. Copies the operator attributes and its output tensors, diff --git a/python_binding/operator/pybind_Concat.cpp b/python_binding/operator/pybind_Concat.cpp index 9e1b3de9e..d2410b03a 100644 --- a/python_binding/operator/pybind_Concat.cpp +++ b/python_binding/operator/pybind_Concat.cpp @@ -24,30 +24,30 @@ void init_Concat(py::module& m) { R"mydelimiter( Initialize a Concat operator. - :param nb_inputs : The number of input tensors to concatenate. - :type nb_inputs : :py:class:`int` - :param axis : The axis along which to concatenate the tensors. - :type axis : :py:class:`int` + :param nb_inputs: The number of input tensors to concatenate. + :type nb_inputs: :py:class:`int` + :param axis: The axis along which to concatenate the tensors, default=0. + :type axis: :py:class:`int` )mydelimiter") .def(py::init<const IOIndex_t, const int>(), py::arg("nb_inputs"), - py::arg("axis")) + py::arg("axis") = 0) .def_static("get_inputs_name", &Concat_Op::getInputsName) .def_static("get_outputs_name", &Concat_Op::getOutputsName) .def_readonly_static("Type", &Concat_Op::Type); declare_registrable<Concat_Op>(m, "ConcatOp"); - m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis"), py::arg("name") = "", + m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis") = 0, py::arg("name") = "", R"mydelimiter( Initialize a node containing a Concat operator. - :param nb_inputs : The number of input tensors to concatenate. - :type nb_inputs : :py:class:`int` - :param axis : The axis along which to concatenate the tensors. - :type axis : :py:class:`int` - :param name : Name of the node. - :type name : :py:class:`str` + :param nb_inputs: The number of input tensors to concatenate. + :type nb_inputs: :py:class:`int` + :param axis: The axis along which to concatenate the tensors. + :type axis: :py:class:`int` + :param name: Name of the node. + :type name: :py:class:`str` )mydelimiter"); } -- GitLab