From 6ec7a51845001978424af65c3df1e083d7303e90 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 3 May 2024 09:15:22 +0000 Subject: [PATCH] Adapt scheduler to the new way the loss work. --- .../aidge/scheduler/SequentialScheduler.hpp | 2 +- python_binding/scheduler/pybind_Scheduler.cpp | 2 +- src/scheduler/SequentialScheduler.cpp | 29 ++++++++++--------- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index 720160125..35dafead6 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -54,7 +54,7 @@ public: /** * @brief Run the provided Computational Graph with a batch of data */ - void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true); + void backward(); private: SchedulingPolicy mSchedulingPolicy; diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 3f763c8ff..4bb7d50fa 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){ py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) - .def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true) + .def("backward", &SequentialScheduler::backward) ; py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler") diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index f044603fb..cbd2f173d 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -73,21 +73,22 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std } } -void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad) { +void Aidge::SequentialScheduler::backward() { // create ad set Grad values - if (instanciateGrad) { compile_gradient(mGraphView); } - - const auto& ordered_outputs = mGraphView->getOrderedOutputs(); - AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \ - right number of data objects to run the backward function. \ - {} outputs detected for the current GraphView when {} were \ - provided.", ordered_outputs.size(), data.size()); - for (std::size_t i = 0; i < ordered_outputs.size(); ++i) { - const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator()); - const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad(); - AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size."); - *t_grad = data[i]->clone(); - } + // if (instanciateGrad) { compile_gradient(mGraphView); } + + // const auto& ordered_outputs = mGraphView->getOrderedOutputs(); + // AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \ + // right number of data objects to run the backward function. \ + // {} outputs detected for the current GraphView when {} were \ + // provided.", ordered_outputs.size(), data.size()); + // for (std::size_t i = 0; i < ordered_outputs.size(); ++i) { + // const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator()); + // const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad(); + // AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size (expected {}, got {}).", t_grad->dims(), data[i]->dims()); + // *t_grad = data[i]->clone(); + // } + // Generate scheduling *only if empty* // If scheduling was already generated (in one or several steps, i.e. one or // several successive call to generateScheduling()), do not generate it twice -- GitLab