From f5212f52ce081611bd78dfd224c2fd4dd4ed263b Mon Sep 17 00:00:00 2001
From: Octave Perrin <operrin@lrtechnologies.fr>
Date: Fri, 29 Nov 2024 11:26:27 +0100
Subject: [PATCH] more python bind

---
 include/aidge/graph/GraphView.hpp         |  6 +-
 python_binding/graph/pybind_GraphView.cpp | 71 +++++++++++++++++++++--
 2 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 921957e29..6d4a8cb28 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -296,9 +296,10 @@ public:
      * compatible with the selected kernel.
      * If not, add a Transpose Operator.
      * 4 - Propagate Tensor dimensions through the consecutive Operators.
-     @params string: backend Backend used, default is cpu
+     @params backend: Backend used, default is cpu
      @params Aidge Datatype: datatype used, default is float32
-     @params vector of vector of DimSize_t: dims
+     @params device: Device to be set
+     @params dims: vector of vector of DimSize_t: dims
      */
     void compile(const std::string& backend = "cpu",
                  const Aidge::DataType datatype = DataType::Float32,
@@ -313,6 +314,7 @@ public:
      * - Updates are made in node dependencies order, because if dims have changed
      *   at any point in the graph, it must de propagated correctly to all succeeding nodes;
      * - It handles cyclic dependencies correctly (currently only induced by the Memorize_Op).
+     * @param dims: vector of vector of dimensions of the inputs of the graphView
      * @return bool Whether it succeeded or not (failure can either raise an exception or return false)
      */
     bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index 725cde2b3..291d4400e 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -241,7 +241,7 @@ void init_GraphView(py::module& m) {
           .def("get_nodes", &GraphView::getNodes,
           R"mydelimiter(
           Get the Nodes in the GraphView.
-          
+
           :return: List of the GraphView's Nodes
           :rtype: List[Node]
           )mydelimiter")
@@ -255,11 +255,70 @@ void init_GraphView(py::module& m) {
           :return: The Node of the GraphView with corresponding name
           :rtype: Node
           )mydelimiter")
-          .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false)
-          .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>())
-          .def("__call__", &GraphView::operator(), py::arg("connectors"))
-          .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
-          .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
+
+          .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false,
+          R"mydelimiter(
+          Compute dimensions of input/output Tensors for each Node's
+          Operator of the GraphView.
+          This function verifies the following conditions:
+           - Every node's dimensions will be updated regardless of if dims were previously forwarded or not;
+           - Updates are made in node dependencies order, because if dims have changed
+             at any point in the graph, it must de propagated correctly to all succeeding nodes;
+           - It handles cyclic dependencies correctly (currently only induced by the Memorize_Op).
+          :param dims: List of list of dimensions of the inputs of the graphView
+          :type dims: List[List[int]]
+          :return: Whether it succeeded or not (failure can either raise an exception or return false)
+          :rtype: bool
+          )mydelimiter")
+
+          .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>(),
+          R"mydelimiter(
+          Assert Datatype, Backend, data format and dimensions along the GraphView are coherent.
+          If not, apply the required transformations.
+          Sets the GraphView ready for computation in four steps:
+          1 - Assert input Tensors' datatype is compatible with each Operator's datatype.
+          If not, a conversion Operator is inserted.
+          2 - Assert input Tensors' backend is compatible with each Operator's backend.
+          If not, add a Transmitter Operator.
+          3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is
+          compatible with the selected kernel.
+          If not, add a Transpose Operator.
+          4 - Propagate Tensor dimensions through the consecutive Operators.
+          :param backend: Backend used, default is cpu
+          :type backend: string
+          :param datatype:  Aidge Datatype used, default is float32
+          :type datatype: DataType
+          :param device: Device to be set
+          :type device: int
+          :param dims: List of list of dimensions of the inputs of the graphView
+          :type dims: List[List[int]]
+          )mydelimiter")
+
+          .def("__call__", &GraphView::operator(), py::arg("connectors"),
+          R"mydelimiter(
+          Functional operator for user-friendly connection interface using an ordered set of Connectors.
+          :parm connectors: The connector be added to current connection
+          :type connectors: Connector
+          :return: The new connector
+          :rtype: Connector
+          )mydelimiter")
+
+          .def("set_datatype", &GraphView::setDataType, py::arg("datatype"),
+          R"mydelimiter(
+          Set the same data type for each Operator of the GraphView object's Nodes.
+          :param datatype: DataType to be set
+          :type datatype: DataType
+          )mydelimiter")
+
+          .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0,
+          R"mydelimiter(
+          Set the same backend for each Operator of the GraphView object's Nodes.
+          :param backend: Backen used, default is cpu
+          :type backend: string
+          :param device: Device to be set
+          :type device: int
+          )mydelimiter")
+
           .def("get_ordered_nodes", &GraphView::getOrderedNodes, py::arg("reversed") = false,
                R"mydelimiter(
                Get a topological node order for an acyclic walk of the graph.
-- 
GitLab