diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 83914b6730bde238d5e2e7b4391bd034c8f4d146..ad31ef1a38a881821293a912342948126f83d28a 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -56,7 +56,7 @@ enum class ConcatAttr { * * The specified axis determines the direction of concatenating. */ - Axis + Axis }; /** @@ -107,7 +107,7 @@ public: * @param[in] nbIn Number of input tensors. * @param[in] axis Axis along which concatenation is performed. */ - Concat_Op(const IOIndex_t nbIn, const std::int32_t axis); + Concat_Op(const IOIndex_t nbIn, const std::int32_t axis = 0); /** * @brief Copy-constructor. Copies the operator attributes and its output tensors, diff --git a/python_binding/operator/pybind_Concat.cpp b/python_binding/operator/pybind_Concat.cpp index 9e1b3de9e7b1f6bd8c84779196c1918294cedb18..d2410b03a7675b181aa012d54b9fa1c93de0ef64 100644 --- a/python_binding/operator/pybind_Concat.cpp +++ b/python_binding/operator/pybind_Concat.cpp @@ -24,30 +24,30 @@ void init_Concat(py::module& m) { R"mydelimiter( Initialize a Concat operator. - :param nb_inputs : The number of input tensors to concatenate. - :type nb_inputs : :py:class:`int` - :param axis : The axis along which to concatenate the tensors. - :type axis : :py:class:`int` + :param nb_inputs: The number of input tensors to concatenate. + :type nb_inputs: :py:class:`int` + :param axis: The axis along which to concatenate the tensors, default=0. + :type axis: :py:class:`int` )mydelimiter") .def(py::init<const IOIndex_t, const int>(), py::arg("nb_inputs"), - py::arg("axis")) + py::arg("axis") = 0) .def_static("get_inputs_name", &Concat_Op::getInputsName) .def_static("get_outputs_name", &Concat_Op::getOutputsName) .def_readonly_static("Type", &Concat_Op::Type); declare_registrable<Concat_Op>(m, "ConcatOp"); - m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis"), py::arg("name") = "", + m.def("Concat", &Concat, py::arg("nb_inputs"), py::arg("axis") = 0, py::arg("name") = "", R"mydelimiter( Initialize a node containing a Concat operator. - :param nb_inputs : The number of input tensors to concatenate. - :type nb_inputs : :py:class:`int` - :param axis : The axis along which to concatenate the tensors. - :type axis : :py:class:`int` - :param name : Name of the node. - :type name : :py:class:`str` + :param nb_inputs: The number of input tensors to concatenate. + :type nb_inputs: :py:class:`int` + :param axis: The axis along which to concatenate the tensors. + :type axis: :py:class:`int` + :param name: Name of the node. + :type name: :py:class:`str` )mydelimiter"); }