diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 07e7b2e470423d8093ce3ea966334cb3e6727ea3..0406835a5810c06262a3fbb1a87a8c51dbfc91fe 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -69,7 +69,7 @@ public: /** * @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph. - * + * * @param data data input tensors */ void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data); @@ -82,7 +82,7 @@ public: /** * @brief Run the provided Computational Graph with a batch of data */ - void backward(bool forwardDims = true, bool verbose = false); + void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true, bool verbose = false); /** * @brief Save in a Markdown file the order of layers execution. diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index d43266fbfb49896e8c97112eba45234ecde95869..49e8de80cbca4e3b43720d921e261599b0db9bfa 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -476,10 +476,21 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve } } -void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { - // Forward dims (if allowed) - if (instanciateGrad) {compile_gradient(mGraphView); } - +void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad, bool verbose) { + // create ad set Grad values + if (instanciateGrad) { compile_gradient(mGraphView); } + + const auto& ordered_outputs = mGraphView->getOrderedOutputs(); + AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \ + right number of data objects to run the backward function. \ + {} outputs detected for the current GraphView when {} were \ + provided.", ordered_outputs.size(), data.size()); + for (std::size_t i = 0; i < ordered_outputs.size(); ++i) { + const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator()); + const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad(); + AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size."); + *t_grad = data[i]->clone(); + } // Generate scheduling *only if empty* // If scheduling was already generated (in one or several steps, i.e. one or // several successive call to generateScheduling()), do not generate it twice @@ -487,18 +498,21 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { this->generateScheduling(); } + // map of node <-> info to display with verbose + const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + // Clear previous scheduling results mScheduling.clear(); - int cpt = 0; - for (auto runnable = mStaticSchedule.crbegin(); runnable != mStaticSchedule.crend(); ++runnable) { + std::size_t cpt = 0; + // run scheduled operators in reverse order + const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep); + for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) { if (verbose) - printf("run: %s\n", - ((*runnable)->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable->get()))).c_str()); + fmt::print("run: {}\n", namePtrTable.at(*runnable)); else drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, - (std::string("running ") + (*runnable)->type() + "_" + - std::to_string(reinterpret_cast<uintptr_t>(runnable->get())))); + (std::string("running ") + namePtrTable.at(*runnable))); const auto tStart = std::chrono::high_resolution_clock::now(); (*runnable)->backward(); const auto tEnd = std::chrono::high_resolution_clock::now(); @@ -506,7 +520,12 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { cpt++; } if (!verbose) drawProgressBar(1.0, 50, " "); - printf("\n"); + fmt::print("\n"); + + ++mStaticScheduleStep; + if (mStaticScheduleStep == mStaticSchedule.size()) { + mStaticScheduleStep = 0; + } }