diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 666483bf8e9429708aad25abb0c6fade59fd784c..a760b14ff85ebf27a1d94bc6a265da9f9830f84e 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -190,7 +190,7 @@ public: /** * @brief Orders the inputs of the GraphView - * @details The Inputs will be ordered in the same order as they come in the std::vector. + * @details The inputs will be ordered in the same order as they come in the std::vector. * Inputs missing from this vector will then be added as per their previous order. * @param std::vector<std::pair<NodePtr, IOIndex_t>>& inputs set of inputs in the wanted order */ @@ -362,7 +362,7 @@ public: const NodePtr otherNode) const; // TODO change it for a vector<vector> ? /** - * @brief Get the Nodes pointed to by the GraphView object. + * @brief Get the Nodes in the GraphView. * @return std::set<NodePtr> */ inline const std::set<NodePtr>& getNodes() const noexcept { return mNodes; } diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 4b9d2ad545c47971b7c0dff029585bb4c9ae5638..7edb4adcf2fbf552cbbc344eaedc721910d52f27 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -55,8 +55,23 @@ void init_GraphView(py::module& m) { :rtype: list[Node] )mydelimiter") - .def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs")) - .def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs")) + .def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs"), + R"mydelimiter( + Orders the inputs of the GraphView + The inputs will be ordered in the same order as they come in the std::vector. + Inputs missing from this vector will then be added as per their previous order. + :param inputs: set of inputs in the wanted order + :type inputs: List[(Node, int)] + )mydelimiter") + + .def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs"), + R"mydelimiter( + Orders the outputs of the GraphView + The outputs will be ordered in the same order as they come in the std::vector. + Outputs missing from this vector will then be added as per their previous order. + :param outputs: set of outputs in the wanted order + :type outputs: List[(Node, int)] + )mydelimiter") .def("add", (void (GraphView::*)(std::shared_ptr<Node>, bool)) & GraphView::add, py::arg("other_node"), py::arg("include_learnable_parameters") = true, @@ -124,9 +139,32 @@ void init_GraphView(py::module& m) { :return: Whether any replacement has been made. :rtype: bool )mydelimiter") - .def("clone", &GraphView::clone) - .def("get_nodes", &GraphView::getNodes) - .def("get_node", &GraphView::getNode, py::arg("node_name")) + .def("clone", &GraphView::clone, + R"mydelimiter( + Clone the current GraphView using a callback function for the Node cloning, allowing to specify how each + Node should be cloned or replaced by another Node type, or removed (i.e. replaced by identity). + When a Node is removed, the clone() method automatically finds the next valid parent in line, going backward in + the graph and connects it if that makes sense without ambiguity (effectively treating the removed Node as an + identity operation). + :param: cloneNode Callback function to clone a node + :type: cloneNode Node + :return: Cloned GraphView + :rtype: GraphView + )mydelimiter") + .def("get_nodes", &GraphView::getNodes, + R"mydelimiter( + Get the Nodes in the GraphView. + :return: List of the GraphView's Nodes + :rtype: List[Node] + )mydelimiter") + .def("get_node", &GraphView::getNode, py::arg("node_name"), + R"mydelimiter( + Get the Node with the corresponding name if it is in the GraphView. + :param: node_name The name of the Node + :type: string + :return: The Node of the GraphView with corresponding name + :rtype: Node + )mydelimiter") .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false) .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) .def("__call__", &GraphView::operator(), py::arg("connectors"))