Skip to content
Snippets Groups Projects
Commit 897f3cb8 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Fix] Add default arg axis=0 for concat

parent ee77d3c4
No related branches found
No related tags found
1 merge request!333[Fix] Attribute snake case
......@@ -56,7 +56,7 @@ enum class ConcatAttr {
*
* The specified axis determines the direction of concatenating.
*/
Axis
Axis
};
/**
......@@ -107,7 +107,7 @@ public:
* @param[in] nbIn Number of input tensors.
* @param[in] axis Axis along which concatenation is performed.
*/
Concat_Op(const IOIndex_t nbIn, const std::int32_t axis);
Concat_Op(const IOIndex_t nbIn, const std::int32_t axis = 0);
/**
* @brief Copy-constructor. Copies the operator attributes and its output tensors,
......
......@@ -24,30 +24,30 @@ void init_Concat(py::module& m) {
R"mydelimiter(
Initialize a Concat operator.
:param nb_inputs : The number of input tensors to concatenate.
:type nb_inputs : :py:class:`int`
:param axis : The axis along which to concatenate the tensors.
:type axis : :py:class:`int`
:param nb_inputs: The number of input tensors to concatenate.
:type nb_inputs: :py:class:`int`
:param axis: The axis along which to concatenate the tensors, default=0.
:type axis: :py:class:`int`
)mydelimiter")
.def(py::init<const IOIndex_t, const int>(),
py::arg("nb_inputs"),
py::arg("axis"))
py::arg("axis") = 0)
.def_static("get_inputs_name", &Concat_Op::getInputsName)
.def_static("get_outputs_name", &Concat_Op::getOutputsName)
.def_readonly_static("Type", &Concat_Op::Type);
declare_registrable<Concat_Op>(m, "ConcatOp");
m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis"), py::arg("name") = "",
m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis") = 0, py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a Concat operator.
:param nb_inputs : The number of input tensors to concatenate.
:type nb_inputs : :py:class:`int`
:param axis : The axis along which to concatenate the tensors.
:type axis : :py:class:`int`
:param name : Name of the node.
:type name : :py:class:`str`
:param nb_inputs: The number of input tensors to concatenate.
:type nb_inputs: :py:class:`int`
:param axis: The axis along which to concatenate the tensors.
:type axis: :py:class:`int`
:param name: Name of the node.
:type name: :py:class:`str`
)mydelimiter");
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment