Skip to content
Snippets Groups Projects
Commit 0182775f authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Fix] Make Squeeze registrable + fix python doc.

parent 89533f70
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!333[Fix] Attribute snake case
...@@ -24,29 +24,29 @@ namespace Aidge { ...@@ -24,29 +24,29 @@ namespace Aidge {
void init_Squeeze(py::module &m) { void init_Squeeze(py::module &m) {
py::class_<Squeeze_Op, std::shared_ptr<Squeeze_Op>, OperatorTensor>( py::class_<Squeeze_Op, std::shared_ptr<Squeeze_Op>, OperatorTensor>(
m, "SqueezeOp", py::multiple_inheritance(), m, "SqueezeOp", py::multiple_inheritance(),
R"mydelimiter( R"mydelimiter(
Initialize squeeze operator Initialize squeeze operator
:param axes : axes to squeeze between [-r;r-1] :param axes: axes to squeeze between [-r;r-1]
with r = input_tensor.nbDims() with r = input_tensor.nbDims()
& r in [-128 , 127] & r in [-128 , 127]
:type axes : :py:class: List[Int] :type axes: :py:class: List[Int]
)mydelimiter") )mydelimiter")
.def("get_inputs_name", &Squeeze_Op::getInputsName) .def("get_inputs_name", &Squeeze_Op::getInputsName)
.def("get_outputs_name", &Squeeze_Op::getOutputsName) .def("get_outputs_name", &Squeeze_Op::getOutputsName)
.def("axes", &Squeeze_Op::axes); .def("axes", &Squeeze_Op::axes);
// Here we bind the constructor of the Squeeze Node. We add an argument
// for each attribute of the operator (in here we only have 'axes') and declare_registrable<Squeeze_Op>(m, "SqueezeOp");
// the last argument is the node's name. m.def("Squeeze", &Squeeze, py::arg("axes") = std::vector<int8_t>({}),
m.def("Squeeze", &Squeeze, py::arg("axes") = std::vector<int8_t>({}),
py::arg("name") = "", py::arg("name") = "",
R"mydelimiter( R"mydelimiter(
Initialize a node containing a squeeze operator. Initialize a node containing a squeeze operator.
:param axes : axes to squeeze between [-r;r-1] :param axes: axes to squeeze between [-r;r-1]
with r = input_tensor.nbDims() with r = input_tensor.nbDims()
& r in [-128 , 127] & r in [-128 , 127]
:type axes : :py:class: List[Int] :type axes: :py:class: List[Int]
:param name : name of the node. :param name: name of the node.
)mydelimiter"); :type name: str
)mydelimiter");
} }
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment