Skip to content
Snippets Groups Projects
Commit 759c001f authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Fix] Make Unsqueeze registrable

parent a547ec1c
No related branches found
No related tags found
1 merge request!333[Fix] Attribute snake case
...@@ -23,26 +23,25 @@ void init_Unsqueeze(py::module &m) { ...@@ -23,26 +23,25 @@ void init_Unsqueeze(py::module &m) {
py::class_<Unsqueeze_Op, std::shared_ptr<Unsqueeze_Op>, OperatorTensor>( py::class_<Unsqueeze_Op, std::shared_ptr<Unsqueeze_Op>, OperatorTensor>(
m, "UnsqueezeOp", py::multiple_inheritance(), m, "UnsqueezeOp", py::multiple_inheritance(),
R"mydelimiter( R"mydelimiter(
Initialize an unsqueeze operator. Initialize an unsqueeze operator.
:param axes : axes to unsqueeze between [-r;r-1] :param axes: axes to unsqueeze between [-r;r-1] with r = input_tensor.nbDims() + len(axes)
with r = input_tensor.nbDims() + len(axes) :type axes: :py:class: List[Int]
:type axes : :py:class: List[Int]
)mydelimiter") )mydelimiter")
// Here we bind the methods of the Unsqueeze_Op that will want to access // Here we bind the methods of the Unsqueeze_Op that will want to access
.def("get_inputs_name", &Unsqueeze_Op::getInputsName) .def("get_inputs_name", &Unsqueeze_Op::getInputsName)
.def("get_outputs_name", &Unsqueeze_Op::getOutputsName) .def("get_outputs_name", &Unsqueeze_Op::getOutputsName)
.def("axes", &Unsqueeze_Op::axes); .def_readonly_static("Type", &Unsqueeze_Op::Type)
// Here we bind the constructor of the Unsqueeze Node. We add an argument for ;
// each attribute of the operator (in here we only have 'axes') and the last
// argument is the node's name. declare_registrable<Unsqueeze_Op>(m, "UnsqueezeOp");
m.def("Unsqueeze", &Unsqueeze, py::arg("axes") = std::vector<int8_t>({}), m.def("Unsqueeze", &Unsqueeze, py::arg("axes") = std::vector<int8_t>({}),
py::arg("name") = "", py::arg("name") = "",
R"mydelimiter( R"mydelimiter(
Initialize a node containing an unsqueeze operator. Initialize a node containing an unsqueeze operator.
:param axes : axes to unsqueeze between [-r;r-1] :param axes: axes to unsqueeze between [-r;r-1] with r = input_tensor.nbDims() + len(axes)
with r = input_tensor.nbDims() + len(axes) :type axes: :py:class: List[Int]
:type axes : :py:class: List[Int] :param name: name of the node.
:param name : name of the node. )mydelimiter");
)mydelimiter"); }
}
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment