Skip to content
Snippets Groups Projects
Commit 7729a5a6 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add back instantiateGrad arg in Scheduler::backward().

parent c0cc9273
No related branches found
No related tags found
No related merge requests found
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
/** /**
* @brief Run the provided Computational Graph with a batch of data * @brief Run the provided Computational Graph with a batch of data
*/ */
void backward(); void backward(bool instantiateGrad = true);
private: private:
SchedulingPolicy mSchedulingPolicy; SchedulingPolicy mSchedulingPolicy;
......
...@@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){ ...@@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler") py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>()) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
.def("backward", &SequentialScheduler::backward) .def("backward", &SequentialScheduler::backward, py::arg("instanciate_grad")=true)
; ;
py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler") py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
......
...@@ -73,7 +73,12 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std ...@@ -73,7 +73,12 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
} }
} }
void Aidge::SequentialScheduler::backward() { void Aidge::SequentialScheduler::backward(bool instanciateGrad) {
// create ad set Grad values
if (instanciateGrad) { compile_gradient(mGraphView); }
// TODO: Check output grad are not empty
// Generate scheduling *only if empty* // Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or // If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice // several successive call to generateScheduling()), do not generate it twice
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment