diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 6dcec5aaa4fa80aefebd538a1728445051ca080e..769f4cfbd1cddf61ca5ef47b15725b80306e3b6b 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -55,6 +55,11 @@ public: */ void forward(bool forwardDims = true, bool verbose = false); + /** + * @brief Run the provided Computational Graph with a batch of data + */ + void backward(bool forwardDims = true, bool verbose = false); + /** * @brief Save in a Markdown file the order of layers execution. * @param fileName Name of the generated file. diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 3afbcd0442fd40214687751d50bfc98809bba840..074f3a98e30411679f268f22b31ddaab5bd44bba 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -18,8 +18,9 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" -#include "aidge/utils/Types.h" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Types.h" +#include "aidge/recipies/GraphViewHelper.hpp" void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { putchar('['); @@ -208,6 +209,40 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { printf("\n"); } +void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { + // Forward dims (if allowed) + if (instanciateGrad) {instanciateGradient(mGraphView); } + + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. one or + // several successive call to generateScheduling()), do not generate it twice + if (mStaticSchedule.empty()) { + this->generateScheduling(); + } + + // Clear previous scheduling results + mScheduling.clear(); + + int cpt = 0; + for (auto runnable = mStaticSchedule.crbegin(); runnable != mStaticSchedule.crend(); ++runnable) { + if (verbose) + printf("run: %s\n", + ((*runnable)->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable->get()))).c_str()); + else + drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, + (std::string("running ") + (*runnable)->type() + "_" + + std::to_string(reinterpret_cast<uintptr_t>(runnable->get())))); + const auto tStart = std::chrono::high_resolution_clock::now(); + (*runnable)->backward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement(*runnable, tStart, tEnd)); + cpt++; + } + if (!verbose) drawProgressBar(1.0, 50, " "); + printf("\n"); +} + + void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { FILE* fp = std::fopen((fileName + ".mmd").c_str(), "w"); std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n");