From 6ec7a51845001978424af65c3df1e083d7303e90 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 3 May 2024 09:15:22 +0000
Subject: [PATCH] Adapt scheduler to the new way the loss works.

---
 .../aidge/scheduler/SequentialScheduler.hpp   |  2 +-
 python_binding/scheduler/pybind_Scheduler.cpp |  2 +-
 src/scheduler/SequentialScheduler.cpp         | 29 ++++++++++---------
 3 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp
index 720160125..35dafead6 100644
--- a/include/aidge/scheduler/SequentialScheduler.hpp
+++ b/include/aidge/scheduler/SequentialScheduler.hpp
@@ -54,7 +54,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true);
+    void backward();
 
 private:
     SchedulingPolicy mSchedulingPolicy;
diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp
index 3f763c8ff..4bb7d50fa 100644
--- a/python_binding/scheduler/pybind_Scheduler.cpp
+++ b/python_binding/scheduler/pybind_Scheduler.cpp
@@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){
     py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
-    .def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true)
+    .def("backward", &SequentialScheduler::backward)
     ;
 
     py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp
index f044603fb..cbd2f173d 100644
--- a/src/scheduler/SequentialScheduler.cpp
+++ b/src/scheduler/SequentialScheduler.cpp
@@ -73,21 +73,22 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
     }
 }
 
-void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad) {
+void Aidge::SequentialScheduler::backward() {
     // create ad set Grad values
-    if (instanciateGrad) { compile_gradient(mGraphView); }
-
-    const auto& ordered_outputs = mGraphView->getOrderedOutputs();
-    AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
-                   right number of data objects to run the backward function. \
-                   {} outputs detected for the current GraphView when {} were \
-                   provided.", ordered_outputs.size(), data.size());
-    for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
-        const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
-        const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
-        AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size.");
-        *t_grad = data[i]->clone();
-    }
+    // if (instanciateGrad) { compile_gradient(mGraphView); }
+
+    // const auto& ordered_outputs = mGraphView->getOrderedOutputs();
+    // AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
+    //                right number of data objects to run the backward function. \
+    //                {} outputs detected for the current GraphView when {} were \
+    //                provided.", ordered_outputs.size(), data.size());
+    // for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
+    //     const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
+    //     const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
+    //     AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size (expected {}, got {}).", t_grad->dims(), data[i]->dims());
+    //     *t_grad = data[i]->clone();
+    // }
+
     // Generate scheduling *only if empty*
     // If scheduling was already generated (in one or several steps, i.e. one or
     // several successive call to generateScheduling()), do not generate it twice
-- 
GitLab