Forked from
Eclipse Projects / aidge / aidge_core
709 commits behind, 20 commits ahead of the upstream repository.
-
Octave Perrin authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
pybind_GraphView.cpp 9.80 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/data/Data.hpp"
namespace py = pybind11;
namespace Aidge {

/**
 * @brief Register the Python bindings of Aidge::GraphView on module @p m.
 *
 * Exposes the GraphView class (construction, inspection, editing, replacement,
 * cloning, dims forwarding / compilation helpers) and the free function
 * get_connected_graph_view.
 *
 * @param m pybind11 module the bindings are added to.
 */
void init_GraphView(py::module& m) {
    py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView")
        .def(py::init<>())
        .def("save", &GraphView::save,
             py::arg("path"), py::arg("verbose") = false, py::arg("show_producers") = true,
             R"mydelimiter(
Save the GraphView as a Mermaid graph in a .md file at the specified location.

:param path: save location
:type path: str
:param verbose: verbosity flag forwarded to GraphView::save, default False.
:type verbose: bool, optional
:param show_producers: flag forwarded to GraphView::save; as its name says, toggles showing Producer nodes. Default True.
:type show_producers: bool, optional
)mydelimiter")
        // Overloaded C++ members are disambiguated with py::overload_cast (as is
        // already done for "replace" below) instead of C-style member-pointer
        // casts: a signature mismatch then fails to compile rather than possibly
        // degrading into a reinterpret_cast.
        .def("inputs", py::overload_cast<>(&GraphView::inputs, py::const_))
        .def("outputs", py::overload_cast<>(&GraphView::outputs, py::const_))
        .def("in_view", py::overload_cast<const NodePtr&>(&GraphView::inView, py::const_))
        .def("in_view", py::overload_cast<const std::string&>(&GraphView::inView, py::const_))
        .def("root_node", &GraphView::rootNode)
        .def("set_root_node", &GraphView::setRootNode, py::arg("node"))
        .def("__repr__", &GraphView::repr)
        .def("__len__", [](const GraphView& g) { return g.getNodes().size(); })
        .def("log_outputs", &GraphView::logOutputs, py::arg("path"))
        .def("get_ordered_inputs", &GraphView::getOrderedInputs)
        .def("get_ordered_outputs", &GraphView::getOrderedOutputs)
        .def("get_output_nodes", &GraphView::outputNodes,
             R"mydelimiter(
Get set of output Nodes.

:rtype: set[Node]
)mydelimiter")
        .def("get_input_nodes", &GraphView::inputNodes,
             R"mydelimiter(
Get set of input Nodes.

:rtype: set[Node]
)mydelimiter")
        .def("set_ordered_inputs", &GraphView::setOrderedInputs, py::arg("inputs"),
             R"mydelimiter(
Orders the inputs of the GraphView.

The inputs will be ordered in the same order as they come in the given list.
Inputs missing from this list will then be added as per their previous order.

:param inputs: inputs in the wanted order
:type inputs: list[tuple[Node, int]]
)mydelimiter")
        .def("set_ordered_outputs", &GraphView::setOrderedOutputs, py::arg("outputs"),
             R"mydelimiter(
Orders the outputs of the GraphView.

The outputs will be ordered in the same order as they come in the given list.
Outputs missing from this list will then be added as per their previous order.

:param outputs: outputs in the wanted order
:type outputs: list[tuple[Node, int]]
)mydelimiter")
        .def("add", py::overload_cast<std::shared_ptr<Node>, bool>(&GraphView::add),
             py::arg("other_node"), py::arg("include_learnable_parameters") = true,
             R"mydelimiter(
Include a Node to the current GraphView object.

:param other_node: Node to add
:type other_node: Node
:param include_learnable_parameters: include non-data inputs, like weights and biases, default True.
:type include_learnable_parameters: bool, optional
)mydelimiter")
        .def("add", py::overload_cast<std::shared_ptr<GraphView>, bool>(&GraphView::add),
             py::arg("other_graph"), py::arg("include_learnable_parameters") = true,
             R"mydelimiter(
Include a GraphView to the current GraphView object.

:param other_graph: GraphView to add
:type other_graph: GraphView
:param include_learnable_parameters: include non-data inputs, like weights and biases, default True.
:type include_learnable_parameters: bool, optional
:rtype: bool
)mydelimiter")
        .def("add_child",
             py::overload_cast<std::shared_ptr<Node>,
                               std::shared_ptr<Node>,
                               IOIndex_t,
                               IOIndex_t>(&GraphView::addChild),
             py::arg("to_other_node"), py::arg("from_out_node") = nullptr,
             py::arg("from_tensor") = 0U, py::arg("to_tensor") = gk_IODefaultIndex,
             R"mydelimiter(
Include a Node to the current GraphView object.

:param to_other_node: Node to add
:type to_other_node: Node
:param from_out_node: Node inside the GraphView the new Node will be linked to (it will become a parent of the new Node). If the GraphView only has one output Node, then default to this Node.
:type from_out_node: Node, optional
:param from_tensor: Output Tensor ID of the already included Node. Default to 0.
:type from_tensor: int, optional
:param to_tensor: Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning first available data input for the Node.
:type to_tensor: int, optional
)mydelimiter")
        .def_static("replace",
                    py::overload_cast<const std::shared_ptr<GraphView>&, const std::shared_ptr<GraphView>&>(&GraphView::replace),
                    py::arg("old_graph"), py::arg("new_graph"),
                    R"mydelimiter(
Replace the old set of Nodes in a GraphView with the new set of given Nodes in a GraphView if possible in every GraphView.

:param old_graph: GraphView of Nodes actually connected in GraphViews.
:type old_graph: GraphView
:param new_graph: GraphView of Nodes with inner connections already taken care of.
:type new_graph: GraphView
:return: Whether any replacement has been made.
:rtype: bool
)mydelimiter")
        .def_static("replace",
                    py::overload_cast<const std::set<NodePtr>&, const std::set<NodePtr>&>(&GraphView::replace),
                    py::arg("old_nodes"), py::arg("new_nodes"),
                    R"mydelimiter(
Replace the old set of Nodes with the new set of given Nodes if possible in every GraphView.

:param old_nodes: Nodes actually connected in GraphViews.
:type old_nodes: set[Node]
:param new_nodes: Nodes with inner connections already taken care of.
:type new_nodes: set[Node]
:return: Whether any replacement has been made.
:rtype: bool
)mydelimiter")
        // NOTE(review): the docstring below documents a cloneNode callback, but no
        // py::arg is declared for it here — confirm GraphView::clone's signature
        // and either declare the argument or move this text to a cloneCallback binding.
        .def("clone", &GraphView::clone,
             R"mydelimiter(
Clone the current GraphView using a callback function for the Node cloning, allowing to specify how each
Node should be cloned or replaced by another Node type, or removed (i.e. replaced by identity).
When a Node is removed, the clone() method automatically finds the next valid parent in line, going backward in
the graph and connects it if that makes sense without ambiguity (effectively treating the removed Node as an
identity operation).

:param cloneNode: Callback function to clone a node
:type cloneNode: Callable[[Node], Node]
:return: Cloned GraphView
:rtype: GraphView
)mydelimiter")
        .def("get_nodes", &GraphView::getNodes,
             R"mydelimiter(
Get the Nodes in the GraphView.

:return: Set of the GraphView's Nodes
:rtype: set[Node]
)mydelimiter")
        .def("get_node", &GraphView::getNode, py::arg("node_name"),
             R"mydelimiter(
Get the Node with the corresponding name if it is in the GraphView.

:param node_name: The name of the Node
:type node_name: str
:return: The Node of the GraphView with corresponding name
:rtype: Node
)mydelimiter")
        .def("forward_dims", &GraphView::forwardDims,
             py::arg("dims") = std::vector<std::vector<DimSize_t>>(),
             py::arg("allow_data_dependency") = false)
        .def("compile", &GraphView::compile,
             py::arg("backend"), py::arg("datatype"), py::arg("device") = 0,
             py::arg("dims") = std::vector<std::vector<DimSize_t>>())
        .def("__call__", &GraphView::operator(), py::arg("connectors"))
        .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
        .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
        .def("get_ordered_nodes", &GraphView::getOrderedNodes, py::arg("reversed") = false,
             R"mydelimiter(
Get ordered nodes for the graph view.

:param reversed: whether to return the nodes in reversed order, default False.
:type reversed: bool, optional
)mydelimiter")
        .def("get_ranked_nodes", &GraphView::getRankedNodes)
        .def("get_ranked_nodes_name", &GraphView::getRankedNodesName,
             py::arg("format"), py::arg("mark_non_unicity") = true)
        .def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat"))
        ;

    m.def("get_connected_graph_view", &getConnectedGraphView);
}

} // namespace Aidge