Skip to content
Snippets Groups Projects
Commit 108727b7 authored by Cyril Moineau's avatar Cyril Moineau Committed by Maxence Naud
Browse files

Update every binded operator to be registrable.

parent 299e7a40
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,7 @@ void init_Reshape(py::module& m) { ...@@ -21,7 +21,7 @@ void init_Reshape(py::module& m) {
py::class_<Reshape_Op, std::shared_ptr<Reshape_Op>, Attributes, OperatorTensor>(m, "ReshapeOp", py::multiple_inheritance()) py::class_<Reshape_Op, std::shared_ptr<Reshape_Op>, Attributes, OperatorTensor>(m, "ReshapeOp", py::multiple_inheritance())
.def("get_inputs_name", &Reshape_Op::getInputsName) .def("get_inputs_name", &Reshape_Op::getInputsName)
.def("get_outputs_name", &Reshape_Op::getOutputsName); .def("get_outputs_name", &Reshape_Op::getOutputsName);
declare_registrable<Reshape_Op>(m, "ReshapeOp");
m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = ""); m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -21,7 +21,7 @@ void init_Slice(py::module& m) { ...@@ -21,7 +21,7 @@ void init_Slice(py::module& m) {
py::class_<Slice_Op, std::shared_ptr<Slice_Op>, OperatorTensor>(m, "SliceOp", py::multiple_inheritance()) py::class_<Slice_Op, std::shared_ptr<Slice_Op>, OperatorTensor>(m, "SliceOp", py::multiple_inheritance())
.def("get_inputs_name", &Slice_Op::getInputsName) .def("get_inputs_name", &Slice_Op::getInputsName)
.def("get_outputs_name", &Slice_Op::getOutputsName); .def("get_outputs_name", &Slice_Op::getOutputsName);
declare_registrable<Slice_Op>(m, "SliceOp");
m.def("Slice", &Slice, py::arg("starts"), py::arg("ends"), py::arg("axes"), py::arg("name") = ""); m.def("Slice", &Slice, py::arg("starts"), py::arg("ends"), py::arg("axes"), py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -23,7 +23,7 @@ void init_Softmax(py::module& m) { ...@@ -23,7 +23,7 @@ void init_Softmax(py::module& m) {
.def("get_inputs_name", &Softmax_Op::getInputsName) .def("get_inputs_name", &Softmax_Op::getInputsName)
.def("get_outputs_name", &Softmax_Op::getOutputsName) .def("get_outputs_name", &Softmax_Op::getOutputsName)
.def("attributes_name", &Softmax_Op::staticGetAttrsName); .def("attributes_name", &Softmax_Op::staticGetAttrsName);
declare_registrable<Softmax_Op>(m, "SoftmaxOp");
m.def("Softmax", &Softmax, py::arg("axis"), py::arg("name") = ""); m.def("Softmax", &Softmax, py::arg("axis"), py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -21,7 +21,7 @@ void init_Sqrt(py::module& m) { ...@@ -21,7 +21,7 @@ void init_Sqrt(py::module& m) {
py::class_<Sqrt_Op, std::shared_ptr<Sqrt_Op>, OperatorTensor>(m, "SqrtOp", py::multiple_inheritance()) py::class_<Sqrt_Op, std::shared_ptr<Sqrt_Op>, OperatorTensor>(m, "SqrtOp", py::multiple_inheritance())
.def("get_inputs_name", &Sqrt_Op::getInputsName) .def("get_inputs_name", &Sqrt_Op::getInputsName)
.def("get_outputs_name", &Sqrt_Op::getOutputsName); .def("get_outputs_name", &Sqrt_Op::getOutputsName);
declare_registrable<Sqrt_Op>(m, "SqrtOp");
m.def("Sqrt", &Sqrt, py::arg("name") = ""); m.def("Sqrt", &Sqrt, py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -21,7 +21,7 @@ void init_Sub(py::module& m) { ...@@ -21,7 +21,7 @@ void init_Sub(py::module& m) {
py::class_<Sub_Op, std::shared_ptr<Sub_Op>, OperatorTensor>(m, "SubOp", py::multiple_inheritance()) py::class_<Sub_Op, std::shared_ptr<Sub_Op>, OperatorTensor>(m, "SubOp", py::multiple_inheritance())
.def("get_inputs_name", &Sub_Op::getInputsName) .def("get_inputs_name", &Sub_Op::getInputsName)
.def("get_outputs_name", &Sub_Op::getOutputsName); .def("get_outputs_name", &Sub_Op::getOutputsName);
declare_registrable<Sub_Op>(m, "SubOp");
m.def("Sub", &Sub, py::arg("name") = ""); m.def("Sub", &Sub, py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -27,12 +27,15 @@ namespace Aidge { ...@@ -27,12 +27,15 @@ namespace Aidge {
template <DimIdx_t DIM> template <DimIdx_t DIM>
void declare_Transpose(py::module &m) { void declare_Transpose(py::module &m) {
const std::string pyClassName("TransposeOp" + std::to_string(DIM) + "D");
py::class_<Transpose_Op<DIM>, std::shared_ptr<Transpose_Op<DIM>>, Attributes, OperatorTensor>( py::class_<Transpose_Op<DIM>, std::shared_ptr<Transpose_Op<DIM>>, Attributes, OperatorTensor>(
m, ("TransposeOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) m, ("TransposeOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance())
.def("get_inputs_name", &Transpose_Op<DIM>::getInputsName) .def("get_inputs_name", &Transpose_Op<DIM>::getInputsName)
.def("get_outputs_name", &Transpose_Op<DIM>::getOutputsName) .def("get_outputs_name", &Transpose_Op<DIM>::getOutputsName)
.def("attributes_name", &Transpose_Op<DIM>::staticGetAttrsName); .def("attributes_name", &Transpose_Op<DIM>::staticGetAttrsName);
declare_registrable<Transpose_Op<DIM>>(m, pyClassName);
m.def(("Transpose" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& output_dims_order, m.def(("Transpose" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& output_dims_order,
const std::string& name) { const std::string& name) {
AIDGE_ASSERT(output_dims_order.size() == DIM, "output_dims_order size [{}] does not match DIM [{}]", output_dims_order.size(), DIM); AIDGE_ASSERT(output_dims_order.size() == DIM, "output_dims_order size [{}] does not match DIM [{}]", output_dims_order.size(), DIM);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment