Commit 6ec7a518 authored by Cyril Moineau

Adapt scheduler to the new way the loss works.

parent b18003b6
2 merge requests: !119 "0.2.1", !118 "Update how loss function work"
Pipeline #45096 failed
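
In user-facing terms, the gradient tensors are no longer handed to the scheduler: the loss function is now expected to seed the output gradients itself, and `backward()` only runs the backward pass over the graph. A minimal usage sketch of the before/after call pattern, assuming the Python module is named `aidge_core`; `model`, `input_tensor` and `grad_tensor` are illustrative placeholders, not part of this commit:

    import aidge_core

    # Placeholders for this sketch: `model` is any aidge_core.GraphView and
    # `input_tensor` an aidge_core.Tensor matching its input.
    scheduler = aidge_core.SequentialScheduler(model)

    # Forward pass, using the binding shown in the diff below.
    scheduler.forward(forward_dims=True, data=[input_tensor])

    # Before this commit, the caller computed the loss gradient and passed it in:
    #     scheduler.backward([grad_tensor], instanciate_grad=True)
    # After this commit, the loss function must already have written the output
    # gradients, so backward takes no arguments:
    scheduler.backward()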
@@ -54,7 +54,7 @@ public:
     /**
      * @brief Run the provided Computational Graph with a batch of data
      */
-    void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true);
+    void backward();
 private:
     SchedulingPolicy mSchedulingPolicy;
@@ -34,7 +34,7 @@ void init_Scheduler(py::module& m){
     py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
     .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
     .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
-    .def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true)
+    .def("backward", &SequentialScheduler::backward)
     ;
     py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
@@ -73,21 +73,22 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std
     }
 }

-void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad) {
+void Aidge::SequentialScheduler::backward() {
     // create and set Grad values
-    if (instanciateGrad) { compile_gradient(mGraphView); }
-    const auto& ordered_outputs = mGraphView->getOrderedOutputs();
-    AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
-        right number of data objects to run the backward function. \
-        {} outputs detected for the current GraphView when {} were \
-        provided.", ordered_outputs.size(), data.size());
-    for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
-        const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
-        const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
-        AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size.");
-        *t_grad = data[i]->clone();
-    }
+    // if (instanciateGrad) { compile_gradient(mGraphView); }
+    // const auto& ordered_outputs = mGraphView->getOrderedOutputs();
+    // AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
+    //     right number of data objects to run the backward function. \
+    //     {} outputs detected for the current GraphView when {} were \
+    //     provided.", ordered_outputs.size(), data.size());
+    // for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
+    //     const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
+    //     const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
+    //     AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size (expected {}, got {}).", t_grad->dims(), data[i]->dims());
+    //     *t_grad = data[i]->clone();
+    // }
     // Generate scheduling *only if empty*
     // If scheduling was already generated (in one or several steps, i.e. one or
     // several successive calls to generateScheduling()), do not generate it twice
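
The block commented out above (optional gradient instantiation, a size check, and a copy of the caller-provided tensors into each output's grad) is what the scheduler no longer does; that seeding is now expected to happen before `backward()` is called, typically inside the loss function. A hedged sketch of performing the equivalent seeding by hand from Python; the binding names used here (`get_ordered_outputs`, `get_operator`, `get_output`, `grad`, `copy_from`) are assumptions about the API, not calls confirmed by this diff:

    # Mirrors the removed C++ loop; every binding name below is an assumption.
    # `model` is the GraphView, `output_grads` a list of gradient Tensors,
    # one per ordered output, with matching dimensions.
    for (node, out_idx), grad_value in zip(model.get_ordered_outputs(), output_grads):
        out_tensor = node.get_operator().get_output(out_idx)  # forward output tensor
        # Write the externally computed gradient into the output's grad tensor
        # (the removed C++ did `*t_grad = data[i]->clone();`).
        out_tensor.grad().copy_from(grad_value)

    scheduler.backward()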