Skip to content
Snippets Groups Projects

Fuse bn

Merged Cyril Moineau requested to merge fuseBN into main
4 unresolved threads
1 file
+ 8
18
Compare changes
  • Side-by-side
  • Inline
@@ -34,6 +34,11 @@ void drawProgressBar(double progress, int barWidth, const std::string& additiona
}
void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// TODO: For loop on the list of node to run
// run sequencially every runnable consumers once
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
// setup initial producers list
mComputationNumber = 0;
std::set<std::shared_ptr<Node>> producers;
@@ -180,35 +185,20 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
mScheduling.clear();
this->generateScheduling();
// TODO: For loop on the list of node to run
// run sequencially every runnable consumers once
// TODO: handle memory allocation in scheduler
// TODO: optimize memory usage
int cpt = 0;
for (const auto& runnable : mStaticSchedule) {
bool computationOverForConsumer = true;
for (IOIndex_t parentIDi = 0; parentIDi < runnable->nbInputs(); ++parentIDi) {
if (runnable->getOperator()->getNbConsumedData(parentIDi) <
runnable->getOperator()->getNbRequiredData(parentIDi)) {
computationOverForConsumer = false;
break;
}
}
if (computationOverForConsumer) {
computationOver.insert(runnable);
}
if (verbose)
printf("run: %s\n",
(runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str());
else
drawProgressBar(static_cast<float>(computationOver.size()) / static_cast<float>(mComputationNumber), 50,
drawProgressBar(static_cast<float>(cpt) / static_cast<float>(mStaticSchedule.size()), 50,
(std::string("running ") + runnable->type() + "_" +
std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))));
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd));
cpt++;
}
if (!verbose) drawProgressBar(1.0, 50, " ");
printf("\n");
Loading