From 593039d5efb7a35b629b47661c1dc371c6efbfdc Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 12 Dec 2023 10:24:20 +0000 Subject: [PATCH] [Add] some functions python binding --- python_binding/graph/pybind_GraphView.cpp | 1 + python_binding/graph/pybind_Node.cpp | 2 ++ python_binding/operator/pybind_Operator.cpp | 1 + python_binding/recipies/pybind_Recipies.cpp | 9 +++++++-- python_binding/scheduler/pybind_Scheduler.cpp | 1 + 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 61392470a..12a007ab9 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -97,6 +97,7 @@ void init_GraphView(py::module& m) { .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) .def("forward_dims", &GraphView::forwardDims) + .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype")) .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_backend", &GraphView::setBackend, py::arg("backend")) diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index aa5c21372..1f655b50a 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -137,6 +137,8 @@ void init_Node(py::module& m) { :rtype: int )mydelimiter") + .def("get_parent", &Node::getParent, py::arg("in_id")) + .def("get_parents", &Node::getParents, R"mydelimiter( Get parents. diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index f9482eda2..89d864ec9 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,6 +20,7 @@ namespace Aidge { void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) + .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("get_raw_output", &Operator::getRawOutput, py::arg("outputIdx")) .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 820b6e12b..8c0c66ec1 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -12,9 +12,11 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> +#include <cstddef> #include <string> #include "aidge/recipies/Recipies.hpp" +#include "aidge/utils/Types.h" namespace py = pybind11; @@ -28,7 +30,7 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. @@ -63,7 +65,10 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - + + m.def("get_conv_horizontal_tiling", static_cast<std::set<std::shared_ptr<Node>>(*)(const std::shared_ptr<Node>&, const DimIdx_t, const std::size_t)>(getConvHorizontalTiling), + py::arg("node"), py::arg("axis"), py::arg("nb_slices")); + // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( // Recipie to remove a flatten operator. diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 85479d41f..d963b81d5 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -21,6 +21,7 @@ void init_Scheduler(py::module& m){ .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) + .def("resetScheduling", &SequentialScheduler::resetScheduling) .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) .def("get_static_scheduling", &SequentialScheduler::getStaticScheduling) ; -- GitLab