From 51c1b1caca9858358e1fcf0fc0de6153caf109fe Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 28 Mar 2024 09:52:47 +0000
Subject: [PATCH] Add backward, grad and some node getters to python binding of
 scheduler, Tensor and GraphView

---
 python_binding/data/pybind_Tensor.cpp         | 1 +
 python_binding/graph/pybind_GraphView.cpp     | 2 ++
 python_binding/scheduler/pybind_Scheduler.cpp | 1 +
 3 files changed, 4 insertions(+)

diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 93389edf6..b97af94ad 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -76,6 +76,7 @@ void init_Tensor(py::module& m){
     .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
     .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
     .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
+    .def("grad", &Tensor::grad)
     .def("dtype", &Tensor::dataType)
     .def("size", &Tensor::size)
     .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index eae05d8e2..f06a70f32 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -31,6 +31,8 @@ void init_GraphView(py::module& m) {
           :type path: str
           )mydelimiter")
           .def("log_outputs", &GraphView::logOutputs, py::arg("path"))
+          .def("get_ordered_inputs", &GraphView::getOrderedInputs)
+          .def("get_ordered_outputs", &GraphView::getOrderedOutputs)
           .def("get_output_nodes", &GraphView::outputNodes,
           R"mydelimiter(
           Get set of output Nodes.
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index 170aa6c27..1b541b606 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -21,6 +21,7 @@ void init_Scheduler(py::module& m){
     py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>())
+    .def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true, py::arg("verbose")=false)
     .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
     .def("resetScheduling", &SequentialScheduler::resetScheduling)
     .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
-- 
GitLab