From de782ae3ab41b8f9a2fcbd29e0245aabf89d144c Mon Sep 17 00:00:00 2001 From: Vincent TEMPLIER <vincent.templier@cea.fr> Date: Wed, 24 Apr 2024 13:44:00 +0000 Subject: [PATCH] Add Python binding for Expand MetaOps recipe --- python_binding/recipes/pybind_Recipes.cpp | 32 +++++++++++++---------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index f122c4116..b85d1c41e 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -21,66 +21,70 @@ namespace py = pybind11; namespace Aidge { -void init_Recipes(py::module &m) { +void init_Recipes(py::module &m) +{ m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. - :param graph_view: Graph view on which we want to apply the recipie + :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + // recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. // :param nodes: The MatMul and Add nodes to fuse. // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter( - Recipie to remove a dropout operator. + Recipe to remove a dropout operator. - :param graph_view: Graph view on which we want to apply the recipie + :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( - Recipie to remove a flatten operator. + Recipe to remove a flatten operator. - :param graph_view: Graph view on which we want to apply the recipie + :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( - // Recipie to remove a flatten operator. + // Recipe to remove a flatten operator. // :param nodes: The flatten operator to remove. // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + // Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. // :param nodes: The MatMul and Add nodes to fuse. // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( - Recipie to remove a flatten operator. + Recipe to remove a flatten operator. - :param graph_view: Graph view on which we want to apply the recipie + :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling), + m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling), py::arg("node"), py::arg("axis"), py::arg("nb_slices")); // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( - // Recipie to remove a flatten operator. + // recipe to remove a flatten operator. // :param nodes: The flatten operator to remove. // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); + + m.def("expand_metaops", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(expandMetaOps), py::arg("graph_view"), py::arg("recursive") = false); } + } // namespace Aidge -- GitLab