From 7729a5a63395e8e1cbaf4c455b1aa9775de43868 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Mon, 6 May 2024 13:19:12 +0000
Subject: [PATCH] Add back instantiateGrad arg in Scheduler::backward().

---
 include/aidge/scheduler/SequentialScheduler.hpp | 2 +-
 python_binding/scheduler/pybind_Scheduler.cpp   | 2 +-
 src/scheduler/SequentialScheduler.cpp           | 7 ++++++-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp
index 35dafead6..a7929fde8 100644
--- a/include/aidge/scheduler/SequentialScheduler.hpp
+++ b/include/aidge/scheduler/SequentialScheduler.hpp
@@ -54,7 +54,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    void backward();
+    void backward(bool instantiateGrad = true);
 
 private:
     SchedulingPolicy mSchedulingPolicy;
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index 4bb7d50fa..b16134da3 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){
     py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
-    .def("backward", &SequentialScheduler::backward)
+    .def("backward", &SequentialScheduler::backward, py::arg("instanciate_grad")=true)
     ;
 
     py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp
index 6e3df1bb3..74b1b3f0c 100644
--- a/src/scheduler/SequentialScheduler.cpp
+++ b/src/scheduler/SequentialScheduler.cpp
@@ -73,7 +73,12 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
     }
 }
 
-void Aidge::SequentialScheduler::backward() {
+void Aidge::SequentialScheduler::backward(bool instanciateGrad) {
+    // create ad set Grad values
+    if (instanciateGrad) { compile_gradient(mGraphView); }
+
+    // TODO: Check output grad are not empty
+
     // Generate scheduling *only if empty*
     // If scheduling was already generated (in one or several steps, i.e. one or
     // several successive call to generateScheduling()), do not generate it twice
-- 
GitLab