Skip to content
Snippets Groups Projects
Commit 897f3cb8 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Fix] Add default arg axis=0 for concat

parent ee77d3c4
No related branches found
No related tags found
1 merge request!333[Fix] Attribute snake case
...@@ -56,7 +56,7 @@ enum class ConcatAttr { ...@@ -56,7 +56,7 @@ enum class ConcatAttr {
* *
* The specified axis determines the direction of concatenating. * The specified axis determines the direction of concatenating.
*/ */
Axis Axis
}; };
/** /**
...@@ -107,7 +107,7 @@ public: ...@@ -107,7 +107,7 @@ public:
* @param[in] nbIn Number of input tensors. * @param[in] nbIn Number of input tensors.
* @param[in] axis Axis along which concatenation is performed. * @param[in] axis Axis along which concatenation is performed.
*/ */
Concat_Op(const IOIndex_t nbIn, const std::int32_t axis); Concat_Op(const IOIndex_t nbIn, const std::int32_t axis = 0);
/** /**
* @brief Copy-constructor. Copies the operator attributes and its output tensors, * @brief Copy-constructor. Copies the operator attributes and its output tensors,
......
...@@ -24,30 +24,30 @@ void init_Concat(py::module& m) { ...@@ -24,30 +24,30 @@ void init_Concat(py::module& m) {
R"mydelimiter( R"mydelimiter(
Initialize a Concat operator. Initialize a Concat operator.
:param nb_inputs : The number of input tensors to concatenate. :param nb_inputs: The number of input tensors to concatenate.
:type nb_inputs : :py:class:`int` :type nb_inputs: :py:class:`int`
:param axis : The axis along which to concatenate the tensors. :param axis: The axis along which to concatenate the tensors, default=0.
:type axis : :py:class:`int` :type axis: :py:class:`int`
)mydelimiter") )mydelimiter")
.def(py::init<const IOIndex_t, const int>(), .def(py::init<const IOIndex_t, const int>(),
py::arg("nb_inputs"), py::arg("nb_inputs"),
py::arg("axis")) py::arg("axis") = 0)
.def_static("get_inputs_name", &Concat_Op::getInputsName) .def_static("get_inputs_name", &Concat_Op::getInputsName)
.def_static("get_outputs_name", &Concat_Op::getOutputsName) .def_static("get_outputs_name", &Concat_Op::getOutputsName)
.def_readonly_static("Type", &Concat_Op::Type); .def_readonly_static("Type", &Concat_Op::Type);
declare_registrable<Concat_Op>(m, "ConcatOp"); declare_registrable<Concat_Op>(m, "ConcatOp");
m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis"), py::arg("name") = "", m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis") = 0, py::arg("name") = "",
R"mydelimiter( R"mydelimiter(
Initialize a node containing a Concat operator. Initialize a node containing a Concat operator.
:param nb_inputs : The number of input tensors to concatenate. :param nb_inputs: The number of input tensors to concatenate.
:type nb_inputs : :py:class:`int` :type nb_inputs: :py:class:`int`
:param axis : The axis along which to concatenate the tensors. :param axis: The axis along which to concatenate the tensors.
:type axis : :py:class:`int` :type axis: :py:class:`int`
:param name : Name of the node. :param name: Name of the node.
:type name : :py:class:`str` :type name: :py:class:`str`
)mydelimiter"); )mydelimiter");
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment