Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
SequentialScheduler.cpp 4.01 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/scheduler/SequentialScheduler.hpp"

#include <algorithm>
#include <chrono>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include <fmt/ranges.h>
#include <fmt/color.h>

#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"

void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std::shared_ptr<Aidge::Tensor>>& data) {
    // Collect all data input of the graph (that are producers)
    if (!data.empty()){
        connectInputs(data);
    }

    // Forward dims (if allowed)
    if (forwardDims) {mGraphView->forwardDims(); }

    // Generate scheduling *only if empty*
    // If scheduling was already generated (in one or several steps, i.e. one or
    // several successive call to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    // Sort static scheduling according to the policy
    std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());

    if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) {
        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
            [](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); });
    }
    else if (mSchedulingPolicy == SchedulingPolicy::AsLateAsPossible) {
        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
            [](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); });
    }

    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    for (const auto& runnable : staticSchedule) {
        const bool skip = !isNodeCondValid(runnable->node);
        Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : "");

        if (!skip) {
            const auto tStart = std::chrono::high_resolution_clock::now();
            runnable->node->forward();
            const auto tEnd = std::chrono::high_resolution_clock::now();
            mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd));
        }
    }

    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}

void Aidge::SequentialScheduler::backward() {
    // TODO: Check output grad are not empty

    // Generate scheduling *only if empty*
    // If scheduling was already generated (in one or several steps, i.e. one or
    // several successive call to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    // map of node <-> info to display with verbose
    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    // run scheduled operators in reverse order
    const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep);
    for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) {
        Log::debug("run: {}", namePtrTable.at((*runnable)->node));

        const auto tStart = std::chrono::high_resolution_clock::now();
        (*runnable)->node->backward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement((*runnable)->node, tStart, tEnd));
    }

    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}