// pybind_GraphView.cpp
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/data/Data.hpp"

namespace py = pybind11;
namespace Aidge {
/**
 * @brief Register the Python ``GraphView`` class on module ``m``.
 *
 * Each binding forwards directly to the corresponding Aidge::GraphView member.
 * Python method names are mostly snake_case versions of the C++ names.
 *
 * @param m pybind11 module the class is added to.
 */
void init_GraphView(py::module& m) {
    py::class_<GraphView, std::shared_ptr<GraphView>>(m, "GraphView")
          .def(py::init<>())
          .def("save", &GraphView::save, py::arg("path"), py::arg("verbose") = false,
          R"mydelimiter(
          Save the GraphView as a Mermaid graph in a .md file at the specified location.

          :param path: save location
          :type path: str
          :param verbose: verbosity flag forwarded to the C++ ``GraphView::save``, default False.
          :type verbose: bool, optional
          )mydelimiter")

          .def("get_output_nodes", &GraphView::outputNodes,
          R"mydelimiter(
          Get set of output Nodes.

          :rtype: list[Node]
          )mydelimiter")

          .def("get_input_nodes", &GraphView::inputNodes,
          R"mydelimiter(
          Get set of input Nodes.

          :rtype: list[Node]
          )mydelimiter")

          // GraphView::add has (at least) a Node overload and a GraphView overload;
          // the explicit cast selects which one each Python "add" binding wraps.
          .def("add",
               static_cast<void (GraphView::*)(std::shared_ptr<Node>, bool)>(&GraphView::add),
               py::arg("other_node"), py::arg("include_learnable_parameters") = true,
          R"mydelimiter(
          Include a Node to the current GraphView object.

          :param other_node: Node to add
          :type other_node: Node
          :param include_learnable_parameters: include non-data inputs, like weights and biases, default True.
          :type include_learnable_parameters: bool, optional
          )mydelimiter")

          .def("add",
               static_cast<bool (GraphView::*)(std::shared_ptr<GraphView>)>(&GraphView::add),
               py::arg("other_graph"),
          R"mydelimiter(
          Include a GraphView to the current GraphView object.

          :param other_graph: GraphView to add
          :type other_graph: GraphView
          )mydelimiter")

          // The explicit cast pins the Node-based addChild signature (addChild may be
          // overloaded in C++). Keyword names stay camelCase so existing Python
          // callers that pass them by name keep working.
          .def("add_child",
               static_cast<void (GraphView::*)(std::shared_ptr<Node>,
                                               std::shared_ptr<Node>,
                                               const IOIndex_t,
                                               IOIndex_t)>(&GraphView::addChild),
               py::arg("toOtherNode"), py::arg("fromOutNode") = nullptr,
               py::arg("fromTensor") = 0U, py::arg("toTensor") = gk_IODefaultIndex,
          R"mydelimiter(
          Include ``toOtherNode`` in the current GraphView, connecting it as a child of ``fromOutNode``.

          :param toOtherNode: Node to add as a child
          :type toOtherNode: Node
          :param fromOutNode: Node the child is connected from, default None.
          :type fromOutNode: Node, optional
          :param fromTensor: output index of ``fromOutNode`` to connect from, default 0.
          :type fromTensor: int, optional
          :param toTensor: input index of ``toOtherNode`` to connect to, default ``gk_IODefaultIndex``.
          :type toTensor: int, optional
          )mydelimiter")

          .def_static("replace", &GraphView::replace, py::arg("old_nodes"), py::arg("new_nodes"),
          R"mydelimiter(
          Replace the old set of Nodes with the new set of given Nodes if possible in every GraphView.

          :param old_nodes: Nodes actually connected in GraphViews.
          :type old_nodes: set[Node]
          :param new_nodes: Nodes with inner connections already taken care of.
          :type new_nodes: set[Node]
          :return: Whether any replacement has been made.
          :rtype: bool
          )mydelimiter")

          // Thin pass-through bindings (no Python docstrings yet).
          .def("get_nodes", &GraphView::getNodes)
          .def("get_node", &GraphView::getNode, py::arg("node_name"))
          .def("forward_dims", &GraphView::forwardDims)
          .def("__call__", &GraphView::operator(), py::arg("connectors"))
          .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
          .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
            ;
}
}  // namespace Aidge