-
Olivier BICHLER authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
SequentialScheduler.cpp 4.01 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/scheduler/SequentialScheduler.hpp"
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <fmt/ranges.h>
#include <fmt/color.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std::shared_ptr<Aidge::Tensor>>& data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); }
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
// Sort static scheduling according to the policy
std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());
if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) {
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); });
}
else if (mSchedulingPolicy == SchedulingPolicy::AsLateAsPossible) {
std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
[](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); });
}
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
for (const auto& runnable : staticSchedule) {
const bool skip = !isNodeCondValid(runnable->node);
Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : "");
if (!skip) {
const auto tStart = std::chrono::high_resolution_clock::now();
runnable->node->forward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd));
}
}
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}
void Aidge::SequentialScheduler::backward() {
// TODO: Check output grad are not empty
// Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or
// several successive call to generateScheduling()), do not generate it twice
if (mStaticSchedule.empty()) {
this->generateScheduling();
}
// map of node <-> info to display with verbose
const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");
// run scheduled operators in reverse order
const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep);
for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) {
Log::debug("run: {}", namePtrTable.at((*runnable)->node));
const auto tStart = std::chrono::high_resolution_clock::now();
(*runnable)->node->backward();
const auto tEnd = std::chrono::high_resolution_clock::now();
mScheduling.push_back(SchedulingElement((*runnable)->node, tStart, tEnd));
}
++mStaticScheduleStep;
if (mStaticScheduleStep == mStaticSchedule.size()) {
mStaticScheduleStep = 0;
}
}