From 32bdd10492f9996598cfbaeca01458c3e3c2307c Mon Sep 17 00:00:00 2001 From: Octave Perrin <operrin@lrtechnologies.fr> Date: Mon, 18 Nov 2024 14:38:34 +0100 Subject: [PATCH] python binds --- include/aidge/graph/GraphView.hpp | 4 +- python_binding/graph/pybind_GraphView.cpp | 48 ++++++++++++++++++++--- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 666483bf8..a760b14ff 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -190,7 +190,7 @@ public: /** * @brief Orders the inputs of the GraphView - * @details The Inputs will be ordered in the same order as they come in the std::vector. + * @details The inputs will be ordered in the same order as they come in the std::vector. * Inputs missing from this vector will then be added as per their previous order. * @param std::vector<std::pair<NodePtr, IOIndex_t>>& inputs set of inputs in the wanted order */ @@ -362,7 +362,7 @@ public: const NodePtr otherNode) const; // TODO change it for a vector<vector> ? /** - * @brief Get the Nodes pointed to by the GraphView object. + * @brief Get the Nodes in the GraphView. * @return std::set<NodePtr> */ inline const std::set<NodePtr>& getNodes() const noexcept { return mNodes; } diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 4b9d2ad54..7edb4adcf 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -55,8 +55,23 @@ void init_GraphView(py::module& m) { :rtype: list[Node] )mydelimiter") - .def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs")) - .def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs")) + .def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs"), + R"mydelimiter( + Orders the inputs of the GraphView + The inputs will be ordered in the same order as they come in the std::vector. + Inputs missing from this vector will then be added as per their previous order. + :param inputs: set of inputs in the wanted order + :type inputs: List[(Node, int)] + )mydelimiter") + + .def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs"), + R"mydelimiter( + Orders the outputs of the GraphView + The outputs will be ordered in the same order as they come in the std::vector. + Outputs missing from this vector will then be added as per their previous order. + :param outputs: set of outputs in the wanted order + :type outputs: List[(Node, int)] + )mydelimiter") .def("add", (void (GraphView::*)(std::shared_ptr<Node>, bool)) & GraphView::add, py::arg("other_node"), py::arg("include_learnable_parameters") = true, @@ -124,9 +139,32 @@ void init_GraphView(py::module& m) { :return: Whether any replacement has been made. :rtype: bool )mydelimiter") - .def("clone", &GraphView::clone) - .def("get_nodes", &GraphView::getNodes) - .def("get_node", &GraphView::getNode, py::arg("node_name")) + .def("clone", &GraphView::clone, + R"mydelimiter( + Clone the current GraphView using a callback function for the Node cloning, allowing to specify how each + Node should be cloned or replaced by another Node type, or removed (i.e. replaced by identity). + When a Node is removed, the clone() method automatically finds the next valid parent in line, going backward in + the graph and connects it if that makes sense without ambiguity (effectively treating the removed Node as an + identity operation). + :param: cloneNode Callback function to clone a node + :type: cloneNode Node + :return: Cloned GraphView + :rtype: GraphView + )mydelimiter") + .def("get_nodes", &GraphView::getNodes, + R"mydelimiter( + Get the Nodes in the GraphView. + :return: List of the GraphView's Nodes + :rtype: List[Node] + )mydelimiter") + .def("get_node", &GraphView::getNode, py::arg("node_name"), + R"mydelimiter( + Get the Node with the corresponding name if it is in the GraphView. + :param: node_name The name of the Node + :type: string + :return: The Node of the GraphView with corresponding name + :rtype: Node + )mydelimiter") .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false) .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) .def("__call__", &GraphView::operator(), py::arg("connectors")) -- GitLab