From f5212f52ce081611bd78dfd224c2fd4dd4ed263b Mon Sep 17 00:00:00 2001 From: Octave Perrin <operrin@lrtechnologies.fr> Date: Fri, 29 Nov 2024 11:26:27 +0100 Subject: [PATCH] more python bind --- include/aidge/graph/GraphView.hpp | 6 +- python_binding/graph/pybind_GraphView.cpp | 71 +++++++++++++++++++++-- 2 files changed, 69 insertions(+), 8 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 921957e29..6d4a8cb28 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -296,9 +296,10 @@ public: * compatible with the selected kernel. * If not, add a Transpose Operator. * 4 - Propagate Tensor dimensions through the consecutive Operators. - @params string: backend Backend used, default is cpu + @params backend: Backend used, default is cpu @params Aidge Datatype: datatype used, default is float32 - @params vector of vector of DimSize_t: dims + @params device: Device to be set + @params dims: vector of vector of DimSize_t: dims */ void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, @@ -313,6 +314,7 @@ public: * - Updates are made in node dependencies order, because if dims have changed * at any point in the graph, it must de propagated correctly to all succeeding nodes; * - It handles cyclic dependencies correctly (currently only induced by the Memorize_Op). + * @param dims: vector of vector of dimensions of the inputs of the graphView * @return bool Whether it succeeded or not (failure can either raise an exception or return false) */ bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 725cde2b3..291d4400e 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -241,7 +241,7 @@ void init_GraphView(py::module& m) { .def("get_nodes", &GraphView::getNodes, R"mydelimiter( Get the Nodes in the GraphView. - + :return: List of the GraphView's Nodes :rtype: List[Node] )mydelimiter") @@ -255,11 +255,70 @@ void init_GraphView(py::module& m) { :return: The Node of the GraphView with corresponding name :rtype: Node )mydelimiter") - .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false) - .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) - .def("__call__", &GraphView::operator(), py::arg("connectors")) - .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) - .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) + + .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, + R"mydelimiter( + Compute dimensions of input/output Tensors for each Node's + Operator of the GraphView. + This function verifies the following conditions: + - Every node's dimensions will be updated regardless of if dims were previously forwarded or not; + - Updates are made in node dependencies order, because if dims have changed + at any point in the graph, it must de propagated correctly to all succeeding nodes; + - It handles cyclic dependencies correctly (currently only induced by the Memorize_Op). + :param dims: List of list of dimensions of the inputs of the graphView + :type dims: List[List[int]] + :return: Whether it succeeded or not (failure can either raise an exception or return false) + :rtype: bool + )mydelimiter") + + .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), + R"mydelimiter( + Assert Datatype, Backend, data format and dimensions along the GraphView are coherent. + If not, apply the required transformations. + Sets the GraphView ready for computation in four steps: + 1 - Assert input Tensors' datatype is compatible with each Operator's datatype. + If not, a conversion Operator is inserted. + 2 - Assert input Tensors' backend is compatible with each Operator's backend. + If not, add a Transmitter Operator. + 3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is + compatible with the selected kernel. + If not, add a Transpose Operator. + 4 - Propagate Tensor dimensions through the consecutive Operators. + :param backend: Backend used, default is cpu + :type backend: string + :param datatype: Aidge Datatype used, default is float32 + :type datatype: DataType + :param device: Device to be set + :type device: int + :param dims: List of list of dimensions of the inputs of the graphView + :type dims: List[List[int]] + )mydelimiter") + + .def("__call__", &GraphView::operator(), py::arg("connectors"), + R"mydelimiter( + Functional operator for user-friendly connection interface using an ordered set of Connectors. + :parm connectors: The connector be added to current connection + :type connectors: Connector + :return: The new connector + :rtype: Connector + )mydelimiter") + + .def("set_datatype", &GraphView::setDataType, py::arg("datatype"), + R"mydelimiter( + Set the same data type for each Operator of the GraphView object's Nodes. + :param datatype: DataType to be set + :type datatype: DataType + )mydelimiter") + + .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0, + R"mydelimiter( + Set the same backend for each Operator of the GraphView object's Nodes. + :param backend: Backen used, default is cpu + :type backend: string + :param device: Device to be set + :type device: int + )mydelimiter") + .def("get_ordered_nodes", &GraphView::getOrderedNodes, py::arg("reversed") = false, R"mydelimiter( Get a topological node order for an acyclic walk of the graph. -- GitLab