Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
1289 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
SequentialScheduler.cpp 4.00 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/scheduler/SequentialScheduler.hpp"

#include <chrono>
#include <memory>
#include <set>
#include <string>

#include <fmt/ranges.h>
#include <fmt/color.h>

#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"

void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std::shared_ptr<Aidge::Tensor>>& data) {
    // Collect all data input of the graph (that are producers)
    if (!data.empty()){
        connectInputs(data);
    }

    // Forward dims (if allowed)
    if (forwardDims) {mGraphView->forwardDims(); }

    // Generate scheduling *only if empty*
    // If scheduling was already generated (in one or several steps, i.e. one or
    // several successive call to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    // Sort static scheduling according to the policy
    std::vector<std::shared_ptr<StaticSchedulingElement>> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end());

    if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) {
        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
            [](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); });
    }
    else if (mSchedulingPolicy == SchedulingPolicy::AsLateAsPossible) {
        std::stable_sort(staticSchedule.begin(), staticSchedule.end(),
            [](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); });
    }

    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    for (const auto& runnable : staticSchedule) {
        Log::debug("run: {}", namePtrTable.at(runnable->node));

        const auto tStart = std::chrono::high_resolution_clock::now();
        runnable->node->forward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd));
    }

    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}

void Aidge::SequentialScheduler::backward(bool instanciateGrad) {
    // create ad set Grad values
    if (instanciateGrad) { compile_gradient(mGraphView); }

    // TODO: Check output grad are not empty

    // Generate scheduling *only if empty*
    // If scheduling was already generated (in one or several steps, i.e. one or
    // several successive call to generateScheduling()), do not generate it twice
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    // map of node <-> info to display with verbose
    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    // run scheduled operators in reverse order
    const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep);
    for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) {
        Log::debug("run: {}", namePtrTable.at((*runnable)->node));

        const auto tStart = std::chrono::high_resolution_clock::now();
        (*runnable)->node->backward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement((*runnable)->node, tStart, tEnd));
    }

    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}