From e40ef693f9faaa2451fe711e73f9b28ab7367a7c Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 26 Mar 2024 15:42:05 +0000 Subject: [PATCH] Upd SequentialScheduler::backward() member function --- include/aidge/scheduler/Scheduler.hpp | 4 +-- src/scheduler/Scheduler.cpp | 41 ++++++++++++++++++++------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 07e7b2e47..0406835a5 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -69,7 +69,7 @@ public: /** * @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph. - * + * * @param data data input tensors */ void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data); @@ -82,7 +82,7 @@ public: /** * @brief Run the provided Computational Graph with a batch of data */ - void backward(bool forwardDims = true, bool verbose = false); + void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true, bool verbose = false); /** * @brief Save in a Markdown file the order of layers execution. diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index d43266fbf..49e8de80c 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -476,10 +476,21 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve } } -void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { - // Forward dims (if allowed) - if (instanciateGrad) {compile_gradient(mGraphView); } - +void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad, bool verbose) { + // create ad set Grad values + if (instanciateGrad) { compile_gradient(mGraphView); } + + const auto& ordered_outputs = mGraphView->getOrderedOutputs(); + AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \ + right number of data objects to run the backward function. \ + {} outputs detected for the current GraphView when {} were \ + provided.", ordered_outputs.size(), data.size()); + for (std::size_t i = 0; i < ordered_outputs.size(); ++i) { + const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator()); + const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad(); + AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size."); + *t_grad = data[i]->clone(); + } // Generate scheduling *only if empty* // If scheduling was already generated (in one or several steps, i.e. one or // several successive call to generateScheduling()), do not generate it twice @@ -487,18 +498,21 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { this->generateScheduling(); } + // map of node <-> info to display with verbose + const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); + // Clear previous scheduling results mScheduling.clear(); - int cpt = 0; - for (auto runnable = mStaticSchedule.crbegin(); runnable != mStaticSchedule.crend(); ++runnable) { + std::size_t cpt = 0; + // run scheduled operators in reverse order + const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep); + for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) { if (verbose) - printf("run: %s\n", - ((*runnable)->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable->get()))).c_str()); + fmt::print("run: {}\n", namePtrTable.at(*runnable)); else drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, - (std::string("running ") + (*runnable)->type() + "_" + - std::to_string(reinterpret_cast<uintptr_t>(runnable->get())))); + (std::string("running ") + namePtrTable.at(*runnable))); const auto tStart = std::chrono::high_resolution_clock::now(); (*runnable)->backward(); const auto tEnd = std::chrono::high_resolution_clock::now(); @@ -506,7 +520,12 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { cpt++; } if (!verbose) drawProgressBar(1.0, 50, " "); - printf("\n"); + fmt::print("\n"); + + ++mStaticScheduleStep; + if (mStaticScheduleStep == mStaticSchedule.size()) { + mStaticScheduleStep = 0; + } } -- GitLab