diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index 35dafead6dc424550df7d83d54f5ec998c3b4d86..a7929fde8a2affdd562d70d11a7c809aaf3357d0 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -54,7 +54,7 @@ public: /** * @brief Run the provided Computational Graph with a batch of data */ - void backward(); + void backward(bool instantiateGrad = true); private: SchedulingPolicy mSchedulingPolicy; diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 4bb7d50fa4f6d3da00dd176e3ff5be037017e3ad..b16134da324383a4542965393257288c49dceed0 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){ py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) - .def("backward", &SequentialScheduler::backward) + .def("backward", &SequentialScheduler::backward, py::arg("instanciate_grad")=true) ; py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler") diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 6e3df1bb38e4a4f7650326ce1c36fcdede7cacc9..74b1b3f0c6e9be164792460669821744661c15b3 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -73,7 +73,12 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std } } -void Aidge::SequentialScheduler::backward() { +void Aidge::SequentialScheduler::backward(bool instanciateGrad) { + // create ad set Grad values + if (instanciateGrad) { compile_gradient(mGraphView); } + + // TODO: Check output grad are not empty + // Generate scheduling *only if empty* // If scheduling was already generated (in one or several steps, i.e. one or // several successive call to generateScheduling()), do not generate it twice