Skip to content
Snippets Groups Projects
Commit 764c6ad8 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] 'backward()' function to SequentialScheduler

parent 95a2f4da
No related branches found
No related tags found
3 merge requests!105version 0.2.0,!88Basic supervised learning,!79Scheduler backward
Pipeline #38467 canceled
...@@ -55,6 +55,11 @@ public: ...@@ -55,6 +55,11 @@ public:
*/ */
void forward(bool forwardDims = true, bool verbose = false); void forward(bool forwardDims = true, bool verbose = false);
/**
* @brief Run the provided Computational Graph with a batch of data
*/
void backward(bool forwardDims = true, bool verbose = false);
/** /**
* @brief Save in a Markdown file the order of layers execution. * @brief Save in a Markdown file the order of layers execution.
* @param fileName Name of the generated file. * @param fileName Name of the generated file.
......
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/recipies/GraphViewHelper.hpp"
void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") {
putchar('['); putchar('[');
...@@ -208,6 +209,40 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { ...@@ -208,6 +209,40 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
printf("\n"); printf("\n");
} }
void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) {
// Forward dims (if allowed)
if (instanciateGrad) {instanciateGradient(mGraphView); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
// Clear previous scheduling results
mScheduling.clear();
int cpt = 0;
for (auto runnable = mStaticSchedule.crbegin(); runnable != mStaticSchedule.crend(); ++runnable) {
if (verbose)
printf("run: %s\n",
((*runnable)->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable->get()))).c_str());
else
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
(std::string("running ") + (*runnable)->type() + "_" +
std::to_string(reinterpret_cast<uintptr_t>(runnable->get()))));
const auto tStart = std::chrono::high_resolution_clock::now();
(*runnable)->backward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(*runnable, tStart, tEnd));
cpt++;
}
if (!verbose) drawProgressBar(1.0, 50, " ");
printf("\n");
}
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
FILE* fp = std::fopen((fileName + ".mmd").c_str(), "w"); FILE* fp = std::fopen((fileName + ".mmd").c_str(), "w");
std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n"); std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment