Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/data/Data.hpp"
namespace py = pybind11;
namespace Aidge {
// Registers the Python binding of Aidge::GraphView (class name "GraphView")
// on module `m`. Every Python-visible method, argument name and default value
// below is part of the Python API and must stay stable.
void init_GraphView(py::module& m) {
    py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView")
        .def(py::init<>())

        .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false,
             R"mydelimiter(
Save the GraphView as a Mermaid graph in a .md file at the specified location.
:param path: save location
:type path: str
:param verbose: forwarded to the C++ ``GraphView::save``, default False.
:type verbose: bool, optional
)mydelimiter")

        // NOTE(review): the docstrings below say "set"; pybind11/stl.h converts a
        // std::set return value to a Python set — confirm the return type of
        // GraphView::outputNodes / inputNodes in GraphView.hpp.
        .def("get_output_nodes", &GraphView::outputNodes,
             R"mydelimiter(
Get set of output Nodes.
:rtype: set[Node]
)mydelimiter")

        .def("get_input_nodes", &GraphView::inputNodes,
             R"mydelimiter(
Get set of input Nodes.
:rtype: set[Node]
)mydelimiter")

        // GraphView::add is overloaded; the named cast selects the Node overload.
        .def("add",
             static_cast<void (GraphView::*)(std::shared_ptr<Node>, bool)>(&GraphView::add),
             py::arg("other_node"), py::arg("include_learnable_parameters") = true,
             R"mydelimiter(
Add a Node to the current GraphView object.
:param other_node: Node to add
:type other_node: Node
:param include_learnable_parameters: include non-data inputs, like weights and biases, default True.
:type include_learnable_parameters: bool, optional
)mydelimiter")

        // Same overload set: this cast selects the GraphView overload.
        .def("add",
             static_cast<bool (GraphView::*)(std::shared_ptr<GraphView>)>(&GraphView::add),
             py::arg("other_graph"),
             R"mydelimiter(
Add the content of a GraphView to the current GraphView object.
:param other_graph: GraphView to add
:type other_graph: GraphView
:rtype: bool
)mydelimiter")

        // NOTE(review): the parameter descriptions are inferred from the
        // GraphView::addChild argument names — confirm against GraphView.hpp.
        // The `= nullptr` default is exposed to Python as None.
        .def("add_child",
             static_cast<void (GraphView::*)(std::shared_ptr<Node>,
                                             std::shared_ptr<Node>,
                                             const IOIndex_t,
                                             IOIndex_t)>(&GraphView::addChild),
             py::arg("toOtherNode"), py::arg("fromOutNode") = nullptr,
             py::arg("fromTensor") = 0U, py::arg("toTensor") = gk_IODefaultIndex,
             R"mydelimiter(
Add a child Node to the current GraphView object.
:param toOtherNode: Node to add as a child
:type toOtherNode: Node
:param fromOutNode: parent Node the child is attached to, default None.
:type fromOutNode: Node, optional
:param fromTensor: IO index on the ``fromOutNode`` side, default 0.
:type fromTensor: int, optional
:param toTensor: IO index on the ``toOtherNode`` side, default gk_IODefaultIndex.
:type toTensor: int, optional
)mydelimiter")

        // Static on the Python side: called as GraphView.replace(old_nodes, new_nodes).
        // NOTE(review): ":type" says set[Node] to match the "set of Nodes" wording —
        // confirm against the GraphView::replace signature.
        .def_static("replace", &GraphView::replace, py::arg("old_nodes"), py::arg("new_nodes"),
             R"mydelimiter(
Replace the old set of Nodes with the new set of given Nodes if possible in every GraphView.
:param old_nodes: Nodes actually connected in GraphViews.
:type old_nodes: set[Node]
:param new_nodes: Nodes with inner connections already taken care of.
:type new_nodes: set[Node]
:return: Whether any replacement has been made.
:rtype: bool
)mydelimiter")

        .def("get_nodes", &GraphView::getNodes)
        .def("get_node", &GraphView::getNode, py::arg("node_name"))
        .def("forward_dims", &GraphView::forwardDims)
        .def("__call__", &GraphView::operator(), py::arg("connectors"))
        .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
        .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)

        // NOTE(review): the draft below is written against Tensor (it takes a
        // Tensor& and reads b.dataType()/b.getImpl()); it is not compiled and
        // does not apply to GraphView as-is.
        // .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
        // // TODO : Should return error if backend not compatible with get
        // if (idx >= b.size()) throw py::index_error();
        // switch(b.dataType()){
        // case DataType::Float32:
        // return py::cast(static_cast<float*>(b.getImpl()->rawPtr())[idx]);
        // case DataType::Int32:
        // return py::cast(static_cast<int*>(b.getImpl()->rawPtr())[idx]);
        // default:
        // return py::none();
        // }
        // })
        ;
}