Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
1953 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Scheduler.hpp 4.77 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_SCHEDULER_H_
#define AIDGE_SCHEDULER_H_

#include <chrono>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "aidge/utils/Types.h"

#include "aidge/data/Tensor.hpp"
#include "aidge/scheduler/MemoryManager.hpp"

namespace Aidge {
class Node;
class GraphView;

/**
 * @brief Sequential scheduler for the nodes of a GraphView.
 *
 * Keeps both a static schedule (lists of nodes per step, see
 * getStaticScheduling()) and a record of timed node executions
 * (see saveSchedulingDiagram()).
 *
 * No destructor or copy/move operation is declared on purpose (Rule of Zero):
 * a user-declared destructor would suppress the implicit move operations and
 * turn every move of a scheduler into a full copy of its members.
 */
class SequentialScheduler {
private:
    /**
     * @brief Node of a static schedule annotated with an "early" and a "late" index.
     * NOTE(review): presumably the earliest/latest step at which the node may
     * run — confirm against generateEarlyLateScheduling() in Scheduler.cpp.
     */
    struct StaticSchedulingElement {
        StaticSchedulingElement(
            std::shared_ptr<Node> node_,
            size_t early_,
            size_t late_)
            : node(std::move(node_)), early(early_), late(late_) {}

        std::shared_ptr<Node> node;
        size_t early;
        size_t late;
    };

    /**
     * @brief Record of one node execution together with its start and end timestamps.
     */
    struct SchedulingElement {
        SchedulingElement(
            std::shared_ptr<Node> node_,
            std::chrono::time_point<std::chrono::high_resolution_clock> start_,
            std::chrono::time_point<std::chrono::high_resolution_clock> end_)
            : node(std::move(node_)), start(start_), end(end_) {}

        std::shared_ptr<Node> node;
        std::chrono::time_point<std::chrono::high_resolution_clock> start;
        std::chrono::time_point<std::chrono::high_resolution_clock> end;
    };

    /**
     * @brief Result of getPriorProducersConsumers().
     * NOTE(review): field semantics are defined in Scheduler.cpp — confirm there.
     */
    struct PriorProducersConsumers {
        bool isPrior = false;
        std::set<std::shared_ptr<Aidge::Node>> requiredProducers;
        std::set<std::shared_ptr<Aidge::Node>> priorConsumers;
    };

public:
    /**
     * @brief Build a scheduler for @p graphView.
     * @param graphView Graph to schedule; the scheduler shares its ownership.
     * @param upperNode Optional node containing @p graphView. Only a weak
     *                  reference is stored, so the scheduler does not keep it alive.
     */
    SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr)
        : mGraphView(std::move(graphView)),
          mUpperNode(upperNode)
    {
        // ctor
    }

    /**
     * @brief Generate the static scheduling of the graph.
     * @param verbose Verbosity flag.
     */
    void generateScheduling(bool verbose = false);

    /**
     * @brief Return the static scheduling annotated with early/late indices.
     * NOTE(review): see StaticSchedulingElement for the (hedged) meaning of the indices.
     */
    std::vector<StaticSchedulingElement> generateEarlyLateScheduling() const;

    /**
     * @brief Reset the scheduler state.
     * NOTE(review): presumably clears the recorded/static schedules — confirm in Scheduler.cpp.
     */
    void resetScheduling();

    /**
     * Generate the memory layout for the current static scheduling.
     * @param incProducers If true, include the producers in the memory layout.
     * @param wrapAroundBuffer If true, allow wrapping in memory planes.
     */
    MemoryManager generateMemory(bool incProducers = false, bool wrapAroundBuffer = false) const;

    /**
     * @brief Place the given data tensors into the data input tensors of the graphView.
     * When there are several data input tensors, they are mapped to producers
     * in the order given by the graph.
     *
     * @param data Data input tensors.
     */
    void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);

    /**
     * @brief Run the provided Computational Graph with a batch of data.
     * @param forwardDims NOTE(review): presumably propagates tensor dimensions
     *                    through the graph before running — confirm in Scheduler.cpp.
     * @param verbose Verbosity flag.
     * @param data Optional input tensors. NOTE(review): presumably forwarded to
     *             connectInputs() when non-empty — confirm in Scheduler.cpp.
     */
    void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});

    /**
     * @brief Save in a Markdown file the order of layers execution.
     * @param fileName Name of the generated file.
     */
    void saveSchedulingDiagram(const std::string& fileName) const;

    /**
     * @brief Save a diagram of a static scheduling to a file.
     * @param fileName Name of the generated file.
     * @param scheduling Static scheduling to draw, e.g. the result of
     *                   generateEarlyLateScheduling().
     */
    void saveStaticSchedulingDiagram(const std::string& fileName, const std::vector<StaticSchedulingElement>& scheduling) const;

    /**
     * @brief Return the nodes of one static-scheduling step, in the order they
     * are called by the scheduler.
     * @param step Index of the step in the static schedule (default: 0).
     * @return std::vector<std::shared_ptr<Node>> Copy of the node list of @p step.
     * @throws std::out_of_range if @p step is not a valid index, e.g. while
     *         the static schedule is still empty.
     */
    inline std::vector<std::shared_ptr<Node>> getStaticScheduling(size_t step = 0) const {
        // Deliberately not noexcept: vector::at() throws std::out_of_range, and an
        // exception escaping a noexcept function would call std::terminate().
        return mStaticSchedule.at(step);
    }

    /**
     * @brief Return the scheduled graph view.
     */
    inline std::shared_ptr<GraphView> getGraphView() const noexcept {
        return mGraphView;
    }

private:
    /**
     * @brief Set of layers receiving an input from currently processing layers
     *
     * @param producers Set of layers ready to run.
     * @return std::set<std::shared_ptr<Node>>
     */
    std::set<std::shared_ptr<Node>> getConsumers(const std::set<std::shared_ptr<Node>>& producers) const;

    /**
     * @brief Number of data elements available on input @p inputIdx of @p node.
     * NOTE(review): exact counting rule is defined in Scheduler.cpp — confirm there.
     */
    NbElts_t getNbAvailableData(const std::shared_ptr<Node>& node, const IOIndex_t inputIdx) const;

    /**
     * @brief Compute the prior producers/consumers of @p node (see PriorProducersConsumers).
     */
    PriorProducersConsumers getPriorProducersConsumers(const std::shared_ptr<Node>& node) const;

    /** @brief Shared ptr to the scheduled graph view */
    std::shared_ptr<GraphView> mGraphView;
    /** @brief Weak ptr to the upper node containing the graph view (may be empty) */
    std::weak_ptr<Node> mUpperNode;
    /** @brief List of SchedulingElement (i.e. nodes with their start and end execution times) */
    std::vector<SchedulingElement> mScheduling;
    /** @brief Static schedule: one list of nodes per step, each in the order the scheduler calls them */
    std::vector<std::vector<std::shared_ptr<Node>>> mStaticSchedule;
    /** @brief NOTE(review): presumably the current step within mStaticSchedule — confirm in Scheduler.cpp */
    size_t mStaticScheduleStep = 0;
};
} // namespace Aidge

#endif /* AIDGE_SCHEDULER_H_ */