Skip to content
Snippets Groups Projects
Commit f5212f52 authored by Octave Perrin's avatar Octave Perrin
Browse files

more python bind

parent 0673c69b
No related branches found
No related tags found
1 merge request!245aidge_core#194: Add documentation for GraphView
This commit is part of merge request !245. Comments created here will be created in the context of that merge request.
...@@ -296,9 +296,10 @@ public: ...@@ -296,9 +296,10 @@ public:
* compatible with the selected kernel. * compatible with the selected kernel.
* If not, add a Transpose Operator. * If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators. * 4 - Propagate Tensor dimensions through the consecutive Operators.
@params string: backend Backend used, default is cpu @params backend: Backend used, default is cpu
@params Aidge Datatype: datatype used, default is float32 @params Aidge Datatype: datatype used, default is float32
@params vector of vector of DimSize_t: dims @params device: Device to be set
@params dims: vector of vector of DimSize_t: dims
*/ */
void compile(const std::string& backend = "cpu", void compile(const std::string& backend = "cpu",
const Aidge::DataType datatype = DataType::Float32, const Aidge::DataType datatype = DataType::Float32,
...@@ -313,6 +314,7 @@ public: ...@@ -313,6 +314,7 @@ public:
* - Updates are made in node dependencies order, because if dims have changed * - Updates are made in node dependencies order, because if dims have changed
* at any point in the graph, it must de propagated correctly to all succeeding nodes; * at any point in the graph, it must de propagated correctly to all succeeding nodes;
* - It handles cyclic dependencies correctly (currently only induced by the Memorize_Op). * - It handles cyclic dependencies correctly (currently only induced by the Memorize_Op).
* @param dims: vector of vector of dimensions of the inputs of the graphView
* @return bool Whether it succeeded or not (failure can either raise an exception or return false) * @return bool Whether it succeeded or not (failure can either raise an exception or return false)
*/ */
bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
......
...@@ -241,7 +241,7 @@ void init_GraphView(py::module& m) { ...@@ -241,7 +241,7 @@ void init_GraphView(py::module& m) {
.def("get_nodes", &GraphView::getNodes, .def("get_nodes", &GraphView::getNodes,
R"mydelimiter( R"mydelimiter(
Get the Nodes in the GraphView. Get the Nodes in the GraphView.
:return: List of the GraphView's Nodes :return: List of the GraphView's Nodes
:rtype: List[Node] :rtype: List[Node]
)mydelimiter") )mydelimiter")
...@@ -255,11 +255,70 @@ void init_GraphView(py::module& m) { ...@@ -255,11 +255,70 @@ void init_GraphView(py::module& m) {
:return: The Node of the GraphView with corresponding name :return: The Node of the GraphView with corresponding name
:rtype: Node :rtype: Node
)mydelimiter") )mydelimiter")
.def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false)
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false,
.def("__call__", &GraphView::operator(), py::arg("connectors")) R"mydelimiter(
.def("set_datatype", &GraphView::setDataType, py::arg("datatype")) Compute dimensions of input/output Tensors for each Node's
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) Operator of the GraphView.
This function verifies the following conditions:
- Every node's dimensions will be updated regardless of if dims were previously forwarded or not;
- Updates are made in node dependencies order, because if dims have changed
at any point in the graph, it must de propagated correctly to all succeeding nodes;
- It handles cyclic dependencies correctly (currently only induced by the Memorize_Op).
:param dims: List of list of dimensions of the inputs of the graphView
:type dims: List[List[int]]
:return: Whether it succeeded or not (failure can either raise an exception or return false)
:rtype: bool
)mydelimiter")
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>(),
R"mydelimiter(
Assert Datatype, Backend, data format and dimensions along the GraphView are coherent.
If not, apply the required transformations.
Sets the GraphView ready for computation in four steps:
1 - Assert input Tensors' datatype is compatible with each Operator's datatype.
If not, a conversion Operator is inserted.
2 - Assert input Tensors' backend is compatible with each Operator's backend.
If not, add a Transmitter Operator.
3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is
compatible with the selected kernel.
If not, add a Transpose Operator.
4 - Propagate Tensor dimensions through the consecutive Operators.
:param backend: Backend used, default is cpu
:type backend: string
:param datatype: Aidge Datatype used, default is float32
:type datatype: DataType
:param device: Device to be set
:type device: int
:param dims: List of list of dimensions of the inputs of the graphView
:type dims: List[List[int]]
)mydelimiter")
.def("__call__", &GraphView::operator(), py::arg("connectors"),
R"mydelimiter(
Functional operator for user-friendly connection interface using an ordered set of Connectors.
:parm connectors: The connector be added to current connection
:type connectors: Connector
:return: The new connector
:rtype: Connector
)mydelimiter")
.def("set_datatype", &GraphView::setDataType, py::arg("datatype"),
R"mydelimiter(
Set the same data type for each Operator of the GraphView object's Nodes.
:param datatype: DataType to be set
:type datatype: DataType
)mydelimiter")
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0,
R"mydelimiter(
Set the same backend for each Operator of the GraphView object's Nodes.
:param backend: Backen used, default is cpu
:type backend: string
:param device: Device to be set
:type device: int
)mydelimiter")
.def("get_ordered_nodes", &GraphView::getOrderedNodes, py::arg("reversed") = false, .def("get_ordered_nodes", &GraphView::getOrderedNodes, py::arg("reversed") = false,
R"mydelimiter( R"mydelimiter(
Get a topological node order for an acyclic walk of the graph. Get a topological node order for an acyclic walk of the graph.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment