Skip to content
Snippets Groups Projects
Commit e40ef693 authored by Maxence Naud's avatar Maxence Naud
Browse files

Upd SequentialScheduler::backward() member function

parent de0f33b3
No related branches found
No related tags found
3 merge requests!105version 0.2.0,!88Basic supervised learning,!79Scheduler backward
...@@ -69,7 +69,7 @@ public: ...@@ -69,7 +69,7 @@ public:
/** /**
* @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph. * @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph.
* *
* @param data data input tensors * @param data data input tensors
*/ */
void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data); void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
...@@ -82,7 +82,7 @@ public: ...@@ -82,7 +82,7 @@ public:
/** /**
* @brief Run the provided Computational Graph with a batch of data * @brief Run the provided Computational Graph with a batch of data
*/ */
void backward(bool forwardDims = true, bool verbose = false); void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true, bool verbose = false);
/** /**
* @brief Save in a Markdown file the order of layers execution. * @brief Save in a Markdown file the order of layers execution.
......
...@@ -476,10 +476,21 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve ...@@ -476,10 +476,21 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve
} }
} }
void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad, bool verbose) {
// Forward dims (if allowed) // create and set Grad values
if (instanciateGrad) {compile_gradient(mGraphView); } if (instanciateGrad) { compile_gradient(mGraphView); }
const auto& ordered_outputs = mGraphView->getOrderedOutputs();
AIDGE_ASSERT(ordered_outputs.size() == data.size(), "You must provide the \
right number of data objects to run the backward function. \
{} outputs detected for the current GraphView when {} were \
provided.", ordered_outputs.size(), data.size());
for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size.");
*t_grad = data[i]->clone();
}
// Generate scheduling *only if empty* // Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or // If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice // several successive call to generateScheduling()), do not generate it twice
...@@ -487,18 +498,21 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { ...@@ -487,18 +498,21 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) {
this->generateScheduling(); this->generateScheduling();
} }
// map of node <-> info to display with verbose
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// Clear previous scheduling results // Clear previous scheduling results
mScheduling.clear(); mScheduling.clear();
int cpt = 0; std::size_t cpt = 0;
for (auto runnable = mStaticSchedule.crbegin(); runnable != mStaticSchedule.crend(); ++runnable) { // run scheduled operators in reverse order
const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep);
for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) {
if (verbose) if (verbose)
printf("run: %s\n", fmt::print("run: {}\n", namePtrTable.at(*runnable));
((*runnable)->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable->get()))).c_str());
else else
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50, drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
(std::string("running ") + (*runnable)->type() + "_" + (std::string("running ") + namePtrTable.at(*runnable)));
std::to_string(reinterpret_cast<uintptr_t>(runnable->get()))));
const auto tStart = std::chrono::high_resolution_clock::now(); const auto tStart = std::chrono::high_resolution_clock::now();
(*runnable)->backward(); (*runnable)->backward();
const auto tEnd = std::chrono::high_resolution_clock::now(); const auto tEnd = std::chrono::high_resolution_clock::now();
...@@ -506,7 +520,12 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { ...@@ -506,7 +520,12 @@ void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) {
cpt++; cpt++;
} }
if (!verbose) drawProgressBar(1.0, 50, " "); if (!verbose) drawProgressBar(1.0, 50, " ");
printf("\n"); fmt::print("\n");
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
} }
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment