diff --git a/aidge_core/dynamic_analysis.py b/aidge_core/dynamic_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..c8cdd6710e7c429de3e1ce0fc202e04e7a7a1cd1 --- /dev/null +++ b/aidge_core/dynamic_analysis.py @@ -0,0 +1,161 @@ +import matplotlib +import matplotlib.pyplot as plt +from functools import partial +import numpy as np +import aidge_core + +class DynamicAnalysis(aidge_core.DynamicAnalysis): + def log_nb_arithm_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale) + + def log_nb_logic_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale) + + def log_nb_comp_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale) + + def log_nb_nl_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_nl_ops, filename, title, log_scale) + + def log_nb_mac_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale) + + def log_nb_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_ops, filename, title, log_scale) + + def log_nb_arithm_int_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_int_ops, filename, title, log_scale) + + def log_nb_arithm_fp_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_fp_ops, filename, title, log_scale) + + def log_nb_ops_by_type(self, filename, title=None, log_scale=False): + return self._log_callback([aidge_core.OperatorStats.get_nb_arithm_int_ops, + aidge_core.OperatorStats.get_nb_arithm_fp_ops, + aidge_core.OperatorStats.get_nb_logic_ops, + aidge_core.OperatorStats.get_nb_comp_ops, + aidge_core.OperatorStats.get_nb_nl_ops], filename, title, log_scale) + + def _log_callback(self, callback, filename, title=None, log_scale=False): + """ + Log a statistic given by an OperatorStats callback member function. + Usage: + + stats = DynamicAnalysis(model) + stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator") + + :param func: OperatorStats member function to call. + :param filename: Output graph file name. + :type filename: str + :param title: Title of the graph. + :type title: str + """ + + namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); + nodes = self.get_graph().get_ordered_nodes() + series = [] + legend = None + + for node in nodes: + if node.type() == "Producer": + continue + + stats = self.get_op_stats(node) + name = namePtrTable[node] + attr = {} + if type(node.get_operator()) is aidge_core.GenericOperatorOp: + # Display Generic Op in orange + attr = {'color': 'orange'} + elif not node.get_operator().is_atomic(): + # Display Meta Op in bold + attr = {'fontweight': 'bold'} + elif node.type() not in aidge_core.get_keys_OperatorStats(): + # Display unsupported operator in red labels + attr = {'color': 'red'} + if attr: + name = (name, attr) + if isinstance(callback, list): + series.append([name, [partial(cb, stats)() for cb in callback]]) + legend = [cb.__name__ for cb in callback] + if title is None: title = str(legend) + else: + series.append([name, partial(callback, stats)()]) + if title is None: title = callback.__name__ + + if title is None: title = str(callback) + if filename is not None: + self._log_bar(series, filename, title, legend, log_scale) + return series + + def _log_bar(self, series, filename, title=None, legend=None, log_scale=False): + names, values = zip(*series) + names_only = [item[0] if isinstance(item, tuple) else item for item in names] + fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5)) + plt.xlim(-0.5, len(names) - 0.5) + if isinstance(values[0], list): + series = [list(i) for i in zip(*values)] + bot = np.zeros(len(series[0])) + for i, serie in enumerate(series): + plt.bar(names_only, serie, bottom=bot) + bot += serie + else: + plt.bar(names_only, values) + if callable(getattr(ax.yaxis, 'minorticks_on', None)): + ax.yaxis.minorticks_on() # introduced in matplotlib 3.9.x + plt.grid(axis='y', which='major', linestyle='--', color='gray') + plt.grid(axis='y', which='minor', linestyle=':', color='lightgray') + formatter0 = matplotlib.ticker.EngFormatter(unit='') + ax.yaxis.set_major_formatter(formatter0) + plt.gca().set_axisbelow(True) + + labels = plt.gca().get_xticks() + tick_labels = plt.gca().get_xticklabels() + for i, label in enumerate(labels): + if isinstance(names[i], tuple): + if 'color' in names[i][1]: + tick_labels[i].set_color(names[i][1]['color']) + elif 'fontweight' in names[i][1]: + tick_labels[i].set_fontweight(names[i][1]['fontweight']) + + plt.xticks(rotation='vertical') + if log_scale: plt.yscale('log') + if title is not None: plt.title(title) + if legend is not None: plt.legend(legend) + plt.savefig(filename, bbox_inches='tight') + + def _log_barh(self, series, filename, title=None, legend=None, log_scale=False): + names, values = zip(*series) + names_only = [item[0] if isinstance(item, tuple) else item for item in names] + fig, ax = plt.subplots(figsize=(10, max(5, len(names)/4))) + plt.ylim(-0.5, len(names) - 0.5) + if isinstance(values[0], list): + series = [list(i) for i in zip(*values)] + left = np.zeros(len(series[0])) + for i, serie in enumerate(series): + plt.barh(names_only, serie, left=left) + left += serie + else: + plt.barh(names_only, values) + if callable(getattr(ax.xaxis, 'minorticks_on', None)): + ax.xaxis.minorticks_on() # introduced in matplotlib 3.9.x + plt.grid(axis='x', which='major', linestyle='--', color='gray') + plt.grid(axis='x', which='minor', linestyle=':', color='lightgray') + formatter0 = matplotlib.ticker.EngFormatter(unit='') + ax.xaxis.set_major_formatter(formatter0) + plt.gca().set_axisbelow(True) + plt.gca().xaxis.set_label_position('top') + plt.gca().xaxis.tick_top() + + labels = plt.gca().get_yticks() + tick_labels = plt.gca().get_yticklabels() + for i, label in enumerate(labels): + if isinstance(names[i], tuple): + if 'color' in names[i][1]: + tick_labels[i].set_color(names[i][1]['color']) + elif 'fontweight' in names[i][1]: + tick_labels[i].set_fontweight(names[i][1]['fontweight']) + + if log_scale: plt.xscale('log') + if title is not None: plt.title(title) + if legend is not None: plt.legend(legend) + plt.savefig(filename, bbox_inches='tight') diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index 0995e4cea50e1a813f433ce4a7a9e94733a68196..8aaedc18d8622e243f237785fd9d3b7f907d65fd 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -42,7 +42,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = :param scheduler: Scheduler instance managing the computation graph. - Uses `graph_view` and `get_static_scheduling` methods + Uses `graph_view` and `get_sequential_static_scheduling` methods to retrieve the computation graph layout and ordered nodes. :type scheduler: aidge_core.Scheduler :param export_folder_path: Path to the folder where the generated export files will be saved. @@ -88,7 +88,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = outputs_size: List[int] = [] # List of aidge_core.Node ordered by scheduler - list_forward_nodes: List[aidge_core.Node] = scheduler.get_static_scheduling() + list_forward_nodes: List[aidge_core.Node] = scheduler.get_sequential_static_scheduling() # If exportLib define use it # else parse component in platform diff --git a/aidge_core/mem_info.py b/aidge_core/mem_info.py index cabc2c72ee973babdf0342ba82057f7ab0769b52..b8d3c61016b1e2fbbd8304a0d082c7f7271fd829 100644 --- a/aidge_core/mem_info.py +++ b/aidge_core/mem_info.py @@ -22,7 +22,7 @@ def compute_default_mem_info(scheduler: aidge_core.Scheduler) -> Tuple[int, List mem_size = 0 # Exclude Producers and the last layers (because the results are stored outside the export) - for i, node in enumerate(scheduler.get_static_scheduling()): + for i, node in enumerate(scheduler.get_sequential_static_scheduling()): if node.type() != "Producer": node_mem_info = [] for out_id in range(node.get_nb_outputs()): @@ -161,7 +161,7 @@ def generate_optimized_memory_info(scheduler: aidge_core.Scheduler, stats_folder mem_planes = mem_manager.get_planes() - for node in scheduler.get_static_scheduling(): + for node in scheduler.get_sequential_static_scheduling(): node_mem_info = [] if node.type() == "Producer": pass diff --git a/aidge_core/static_analysis.py b/aidge_core/static_analysis.py index b4a82a4fbd9ef5205ce39dc5a519f44305bc455d..907bc48f5c104036ef538dfd4d98ecb8163ccb51 100644 --- a/aidge_core/static_analysis.py +++ b/aidge_core/static_analysis.py @@ -4,7 +4,7 @@ from functools import partial import numpy as np import aidge_core -class StaticAnalysisExt(aidge_core.StaticAnalysis): +class StaticAnalysis(aidge_core.StaticAnalysis): def log_nb_params(self, filename, title=None, log_scale=False): namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); nodes = self.get_graph().get_ordered_nodes() @@ -77,7 +77,7 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis): Log a statistic given by an OperatorStats callback member function. Usage: - stats = StaticAnalysisExt(model) + stats = StaticAnalysis(model) stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator") :param func: OperatorStats member function to call. diff --git a/include/aidge/analysis/DynamicAnalysis.hpp b/include/aidge/analysis/DynamicAnalysis.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3dadf79b362dad11ed7e924a29a5c6f8d5fec17d --- /dev/null +++ b/include/aidge/analysis/DynamicAnalysis.hpp @@ -0,0 +1,55 @@ + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_ANALYSIS_DYNAMICANALYSIS_H_ +#define AIDGE_CORE_ANALYSIS_DYNAMICANALYSIS_H_ + +#include <cstddef> // std::size_t +#include <memory> +#include <string> + +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Registrar.hpp" + +namespace Aidge { +/** + * @brief Base class to compute statistics from a scheduled graph + * + */ +class DynamicAnalysis : public std::enable_shared_from_this<DynamicAnalysis> { +public: + DynamicAnalysis() = delete; + DynamicAnalysis(const Scheduler& scheduler); + + virtual ~DynamicAnalysis(); + + std::size_t getNbArithmOps() const; + std::size_t getNbLogicOps() const; + std::size_t getNbCompOps() const; + std::size_t getNbNLOps() const; + std::size_t getNbOps() const; + std::size_t getNbArithmIntOps() const; + std::size_t getNbArithmFpOps() const; + std::size_t getNbMACOps() const; + +protected: + const Scheduler& mScheduler; + + std::size_t accumulate(std::size_t (OperatorStats::*func)() const) const; +}; +} + +#endif /* AIDGE_CORE_ANALYSIS_DYNAMICANALYSIS_H_ */ diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/analysis/OperatorStats.hpp similarity index 89% rename from include/aidge/graph/StaticAnalysis.hpp rename to include/aidge/analysis/OperatorStats.hpp index cc5532224ebd00f17aefbf5c2620a3ef15cfaa2a..ac1abcee7af3809f5b1f68a1733d16aa5f29ba81 100644 --- a/include/aidge/graph/StaticAnalysis.hpp +++ b/include/aidge/analysis/OperatorStats.hpp @@ -10,8 +10,8 @@ * ********************************************************************************/ -#ifndef AIDGE_CORE_GRAPH_STATICANALYSIS_H_ -#define AIDGE_CORE_GRAPH_STATICANALYSIS_H_ +#ifndef AIDGE_CORE_ANALYSIS_OPERATORSTATS_H_ +#define AIDGE_CORE_ANALYSIS_OPERATORSTATS_H_ #include <cstddef> // std::size_t #include <memory> @@ -44,6 +44,14 @@ public: OperatorStats() = delete; OperatorStats(const Operator& op); + /** + * @brief Get the Operator Stats object corresponding to the given node. + * + * @param node Node + * @return std::shared_ptr<OperatorStats> Node's Operator stats + */ + static std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node); + virtual ~OperatorStats(); inline const Operator& getOperator() const noexcept { return mOp; } @@ -156,73 +164,6 @@ protected: const Operator &mOp; }; -/** - * @brief Base class to compute statistics from a GraphView - * - */ -class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { -public: - StaticAnalysis() = delete; - StaticAnalysis(std::shared_ptr<GraphView> graph); - - virtual ~StaticAnalysis(); - - inline const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } - - /** - * @brief Get the Operator Stats object corresponding to the given node. - * - * @param node Node - * @return std::shared_ptr<OperatorStats> Node's Operator stats - */ - std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const; - - /** - * @brief Get the number of parameters associated to a node. This includes - * all Producers directly connected to the node's inputs as well as all - * internal Producers (in case of a meta operator). - * - * Note: this function does not check if parameters are shared between - * several nodes or not. This means that simply adding parameters count from - * several nodes may lead to a higher number of parameters than in reality - * if some of them are shared. - * - * @param node Node - * @return std::size_t Number of parameters - */ - virtual std::size_t getNbParams(std::shared_ptr<Node> node) const; - - /** - * @brief Get the total parameters memory size, in bits, associated to a node. - * This includes all Producers directly connected to the node's inputs as - * well as all internal Producers (in case of a meta operator). - * - * Note: this function does not check if parameters are shared between - * several nodes or not. This means that simply adding parameters size from - * several nodes may lead to a higher parameter size than in reality - * if some of them are shared. - * - * @param node Node - * @return std::size_t Total parameters memory, in bits - */ - virtual std::size_t getParamsSize(std::shared_ptr<Node> node) const; - - std::size_t getNbArithmOps() const; - std::size_t getNbLogicOps() const; - std::size_t getNbCompOps() const; - std::size_t getNbNLOps() const; - std::size_t getNbOps() const; - std::size_t getNbArithmIntOps() const; - std::size_t getNbArithmFpOps() const; - std::size_t getNbMACOps() const; - virtual void summary(bool incProducers = false) const; - -protected: - const std::shared_ptr<GraphView> mGraph; - - std::size_t accumulate(std::size_t (OperatorStats::*func)() const) const; -}; - //////////////////////////////////////////////////////////////////////////////// class MetaOpStats : public OperatorStats { @@ -579,4 +520,4 @@ REGISTRAR(OperatorStats, "Tanh", ElemWiseNLOpStats::create); REGISTRAR(OperatorStats, "Pow", ElemWiseNLOpStats::create); } -#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */ +#endif /* AIDGE_CORE_ANALYSIS_OPERATORSTATS_H_ */ diff --git a/include/aidge/analysis/StaticAnalysis.hpp b/include/aidge/analysis/StaticAnalysis.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0feadd72eae9d4ae8a31fb393c727cf790d77a9 --- /dev/null +++ b/include/aidge/analysis/StaticAnalysis.hpp @@ -0,0 +1,87 @@ + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_ANALYSIS_STATICANALYSIS_H_ +#define AIDGE_CORE_ANALYSIS_STATICANALYSIS_H_ + +#include <cstddef> // std::size_t +#include <memory> +#include <string> + +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Registrar.hpp" + +namespace Aidge { +/** + * @brief Base class to compute statistics from a GraphView + * + */ +class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { +public: + StaticAnalysis() = delete; + StaticAnalysis(std::shared_ptr<GraphView> graph); + + virtual ~StaticAnalysis(); + + inline const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } + + /** + * @brief Get the number of parameters associated to a node. This includes + * all Producers directly connected to the node's inputs as well as all + * internal Producers (in case of a meta operator). + * + * Note: this function does not check if parameters are shared between + * several nodes or not. This means that simply adding parameters count from + * several nodes may lead to a higher number of parameters than in reality + * if some of them are shared. + * + * @param node Node + * @return std::size_t Number of parameters + */ + virtual std::size_t getNbParams(std::shared_ptr<Node> node) const; + + /** + * @brief Get the total parameters memory size, in bits, associated to a node. + * This includes all Producers directly connected to the node's inputs as + * well as all internal Producers (in case of a meta operator). + * + * Note: this function does not check if parameters are shared between + * several nodes or not. This means that simply adding parameters size from + * several nodes may lead to a higher parameter size than in reality + * if some of them are shared. + * + * @param node Node + * @return std::size_t Total parameters memory, in bits + */ + virtual std::size_t getParamsSize(std::shared_ptr<Node> node) const; + + std::size_t getNbArithmOps() const; + std::size_t getNbLogicOps() const; + std::size_t getNbCompOps() const; + std::size_t getNbNLOps() const; + std::size_t getNbOps() const; + std::size_t getNbArithmIntOps() const; + std::size_t getNbArithmFpOps() const; + std::size_t getNbMACOps() const; + virtual void summary(bool incProducers = false) const; + +protected: + const std::shared_ptr<GraphView> mGraph; + + std::size_t accumulate(std::size_t (OperatorStats::*func)() const) const; +}; +} + +#endif /* AIDGE_CORE_ANALYSIS_STATICANALYSIS_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 61aeb49ba1b03153e27257f5029ad6ade14bbf75..7c309783daaf1a7a4d3bceef24e80e46e6f2e3ba 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -46,7 +46,7 @@ class GraphView; * @see MemoryManager */ class Scheduler { -protected: +public: /** * @struct StaticSchedulingElement * @brief Represents a node in the static schedule. @@ -81,7 +81,7 @@ protected: std::chrono::time_point<std::chrono::high_resolution_clock> start; /** Actual start time of execution */ std::chrono::time_point<std::chrono::high_resolution_clock> end; /** Actual end time of execution */ }; -public: + enum class AvailableDataStatus { Connected, UpperNodeInputFound, @@ -90,7 +90,7 @@ public: NotConnected }; - enum class EarlyLateSort { + enum class SchedulingPolicy { Default, AsSoonAsPossible, AsLateAsPossible @@ -136,12 +136,28 @@ public: void tagConditionalNodes() const; /** - * @brief Get the static scheduling order of nodes. + * @brief Get the static scheduling (after generate scheduling). + * @return Vector of StaticSchedulingElement pointers. + */ + std::vector<StaticSchedulingElement*> getStaticScheduling(std::size_t step = 0) const { + return mStaticSchedule.at(step); + } + + /** + * @brief Get the static scheduling sequential order of nodes. * @param step The step of the static schedule to retrieve (default is 0). - * @param sorting Sorting mode. + * @param policy Sorting mode. * @return Vector of shared pointers to Nodes in their scheduled order. */ - std::vector<std::shared_ptr<Node>> getStaticScheduling(std::size_t step = 0, EarlyLateSort sorting = EarlyLateSort::Default) const; + std::vector<std::shared_ptr<Node>> getSequentialStaticScheduling(std::size_t step = 0, SchedulingPolicy policy = SchedulingPolicy::Default) const; + + /** + * @brief Get the dynamic scheduling (after graph execution). + * @return Vector of SchedulingElement. + */ + std::vector<SchedulingElement> getScheduling() const { + return mScheduling; + } /** * @brief Get the GraphView associated with this Scheduler. @@ -199,14 +215,14 @@ public: * order of execution for the nodes, to a file in Mermaid format. * @param fileName Name of the file to save the diagram (without extension). */ - void saveStaticSchedulingDiagram(const std::string& fileName) const; + void saveStaticSchedulingDiagram(const std::string& fileName, bool ignoreProducers = false) const; void saveFactorizedStaticSchedulingDiagram(const std::string& fileName, size_t minRepeat = 2) const; /** * @brief Save in a Mermaid file the order of layers execution. * @param fileName Name of the generated file. */ - void saveSchedulingDiagram(const std::string& fileName) const; + void saveSchedulingDiagram(const std::string& fileName, bool ignoreProducers = false) const; protected: diff --git a/include/aidge/scheduler/SequentialScheduler.hpp b/include/aidge/scheduler/SequentialScheduler.hpp index 35dafead6dc424550df7d83d54f5ec998c3b4d86..0ae18b085ba8b07bf4f50ad3fdd8d969572543b5 100644 --- a/include/aidge/scheduler/SequentialScheduler.hpp +++ b/include/aidge/scheduler/SequentialScheduler.hpp @@ -25,13 +25,6 @@ namespace Aidge { * Multi-threaded parallel scheduler with dynamic scheduling. */ class SequentialScheduler : public Scheduler { -public: - enum class SchedulingPolicy { - Default, - AsSoonAsPossible, - AsLateAsPossible - }; - public: SequentialScheduler(std::shared_ptr<GraphView> graphView, std::shared_ptr<Node> upperNode = nullptr) : Scheduler(graphView, upperNode), diff --git a/python_binding/analysis/pybind_DynamicAnalysis.cpp b/python_binding/analysis/pybind_DynamicAnalysis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3cd71f741cdc3203c2f3ce62bfd9989fd4d8a674 --- /dev/null +++ b/python_binding/analysis/pybind_DynamicAnalysis.cpp @@ -0,0 +1,39 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/analysis/DynamicAnalysis.hpp" + +namespace py = pybind11; +namespace Aidge { + +class pyDynamicAnalysis: public DynamicAnalysis { +public: + using DynamicAnalysis::DynamicAnalysis; // Inherit constructors + +}; + +void init_DynamicAnalysis(py::module& m){ + py::class_<DynamicAnalysis, std::shared_ptr<DynamicAnalysis>, pyDynamicAnalysis>(m, "DynamicAnalysis", py::multiple_inheritance(), py::dynamic_attr()) + .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) + .def("get_nb_arithm_ops", &DynamicAnalysis::getNbArithmOps) + .def("get_nb_logic_ops", &DynamicAnalysis::getNbLogicOps) + .def("get_nb_comp_ops", &DynamicAnalysis::getNbCompOps) + .def("get_nb_nl_ops", &DynamicAnalysis::getNbNLOps) + .def("get_nb_ops", &DynamicAnalysis::getNbOps) + .def("get_nb_arithm_int_ops", &DynamicAnalysis::getNbArithmIntOps) + .def("get_nb_arithm_fp_ops", &DynamicAnalysis::getNbArithmFpOps) + .def("get_nb_mac_ops", &DynamicAnalysis::getNbMACOps) + ; +} +} diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/analysis/pybind_OperatorStats.cpp similarity index 57% rename from python_binding/graph/pybind_StaticAnalysis.cpp rename to python_binding/analysis/pybind_OperatorStats.cpp index b7c704d722e81b36e8d4988a4503428918e16a5a..be2b79e672345eebd4b16863fad1ec3c36123e06 100644 --- a/python_binding/graph/pybind_StaticAnalysis.cpp +++ b/python_binding/analysis/pybind_OperatorStats.cpp @@ -12,7 +12,7 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> -#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/analysis/OperatorStats.hpp" namespace py = pybind11; namespace Aidge { @@ -74,41 +74,10 @@ public: } }; -class pyStaticAnalysis: public StaticAnalysis { -public: - using StaticAnalysis::StaticAnalysis; // Inherit constructors - - size_t getNbParams(std::shared_ptr<Node> node) const override { - PYBIND11_OVERRIDE( - size_t, - StaticAnalysis, - getNbParams, - node - ); - } - - size_t getParamsSize(std::shared_ptr<Node> node) const override { - PYBIND11_OVERRIDE( - size_t, - StaticAnalysis, - getParamsSize, - node - ); - } - - void summary(bool incProducers) const override { - PYBIND11_OVERRIDE( - void, - StaticAnalysis, - summary, - incProducers - ); - } -}; - -void init_StaticAnalysis(py::module& m){ +void init_OperatorStats(py::module& m){ py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr()) .def(py::init<const Operator&>(), py::arg("op")) + .def_static("get_op_stats", &OperatorStats::getOpStats, py::arg("node")) .def("get_operator", &OperatorStats::getOperator) .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps) .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps) @@ -120,22 +89,5 @@ void init_StaticAnalysis(py::module& m){ .def("get_nb_mac_ops", &OperatorStats::getNbMACOps) ; declare_registrable<OperatorStats>(m, "OperatorStats"); - - py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr()) - .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) - .def("get_graph", &StaticAnalysis::getGraph) - .def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node")) - .def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node")) - .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps) - .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps) - .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps) - .def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps) - .def("get_nb_ops", &StaticAnalysis::getNbOps) - .def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps) - .def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps) - .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps) - .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) - .def("get_op_stats", &StaticAnalysis::getOpStats, py::arg("node")) - ; } } diff --git a/python_binding/analysis/pybind_StaticAnalysis.cpp b/python_binding/analysis/pybind_StaticAnalysis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65ee8e8b07b0e84497b666d6686581d5cf50e0e2 --- /dev/null +++ b/python_binding/analysis/pybind_StaticAnalysis.cpp @@ -0,0 +1,69 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/analysis/StaticAnalysis.hpp" + +namespace py = pybind11; +namespace Aidge { + +class pyStaticAnalysis: public StaticAnalysis { +public: + using StaticAnalysis::StaticAnalysis; // Inherit constructors + + size_t getNbParams(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getNbParams, + node + ); + } + + size_t getParamsSize(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getParamsSize, + node + ); + } + + void summary(bool incProducers) const override { + PYBIND11_OVERRIDE( + void, + StaticAnalysis, + summary, + incProducers + ); + } +}; + +void init_StaticAnalysis(py::module& m){ + py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr()) + .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) + .def("get_graph", &StaticAnalysis::getGraph) + .def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node")) + .def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node")) + .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps) + .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps) + .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps) + .def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps) + .def("get_nb_ops", &StaticAnalysis::getNbOps) + .def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps) + .def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps) + .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps) + .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) + ; +} +} diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index b2aa93dc901e549057847b5e5571afc8f571a1f6..c7a7330b616100b703173d53fc9b7236fb87295e 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -32,7 +32,9 @@ void init_OperatorImpl(py::module&); void init_Log(py::module&); void init_Operator(py::module&); void init_OperatorTensor(py::module&); +void init_OperatorStats(py::module&); void init_StaticAnalysis(py::module&); +void init_DynamicAnalysis(py::module&); void init_Abs(py::module&); void init_Add(py::module&); @@ -136,7 +138,9 @@ void init_Aidge(py::module& m) { init_Log(m); init_Operator(m); init_OperatorTensor(m); + init_OperatorStats(m); init_StaticAnalysis(m); + init_DynamicAnalysis(m); init_Abs(m); init_Add(m); diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index 34ed93520e09167e69b0217f4b099ad48e9fc9b8..582ba46786fe37209a0c2d4770d5f477fcd271bd 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -21,23 +21,39 @@ namespace py = pybind11; namespace Aidge { void init_Scheduler(py::module& m){ - py::enum_<Scheduler::EarlyLateSort>(m, "EarlyLateSort") - .value("Default", Scheduler::EarlyLateSort::Default) - .value("AsSoonAsPossible", Scheduler::EarlyLateSort::AsSoonAsPossible) - .value("AsLateAsPossible", Scheduler::EarlyLateSort::AsLateAsPossible) + py::class_<Scheduler::StaticSchedulingElement>(m, "StaticSchedulingElement") + .def_readonly("node", &Scheduler::StaticSchedulingElement::node) + .def_readonly("early", &Scheduler::StaticSchedulingElement::early) + .def_readonly("late", &Scheduler::StaticSchedulingElement::late) + .def_readonly("earlier_than", &Scheduler::StaticSchedulingElement::earlierThan) + .def_readonly("later_than", &Scheduler::StaticSchedulingElement::laterThan) + ; + + py::class_<Scheduler::SchedulingElement>(m, "SchedulingElement") + .def_readonly("node", &Scheduler::SchedulingElement::node) + .def_readonly("start", &Scheduler::SchedulingElement::start) + .def_readonly("end", &Scheduler::SchedulingElement::end) + ; + + py::enum_<Scheduler::SchedulingPolicy>(m, "SchedulingPolicy") + .value("Default", Scheduler::SchedulingPolicy::Default) + .value("AsSoonAsPossible", Scheduler::SchedulingPolicy::AsSoonAsPossible) + .value("AsLateAsPossible", Scheduler::SchedulingPolicy::AsLateAsPossible) .export_values(); py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def("graph_view", &Scheduler::graphView) .def("tag_conditional_nodes", &Scheduler::tagConditionalNodes) - .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name")) - .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name")) + .def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"), py::arg("ignore_producers") = false) + .def("save_static_scheduling_diagram", &Scheduler::saveStaticSchedulingDiagram, py::arg("file_name"), py::arg("ignore_producers") = false) .def("save_factorized_static_scheduling_diagram", &Scheduler::saveFactorizedStaticSchedulingDiagram, py::arg("file_name"), py::arg("min_repeat") = 2) .def("reset_scheduling", &Scheduler::resetScheduling) .def("clear_scheduling", &Scheduler::clearScheduling) .def("generate_scheduling", &Scheduler::generateScheduling) - .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::EarlyLateSort::Default) + .def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0) + .def("get_sequential_static_scheduling", &Scheduler::getSequentialStaticScheduling, py::arg("step") = 0, py::arg("sorting") = Scheduler::SchedulingPolicy::Default) + .def("get_scheduling", &Scheduler::getScheduling) .def("graph_view", &Scheduler::graphView) .def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) .def("generate_memory_auto_concat", &Scheduler::generateMemoryAutoConcat, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false) diff --git a/src/analysis/DynamicAnalysis.cpp b/src/analysis/DynamicAnalysis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0398201543840ddad106d4eb42a99d4b906b029c --- /dev/null +++ b/src/analysis/DynamicAnalysis.cpp @@ -0,0 +1,57 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/analysis/DynamicAnalysis.hpp" + +#include <cstddef> // std::size_t +#include <memory> +#include <numeric> // std::accumulate +#include <set> + +#include <fmt/core.h> // fmt::println +#include <fmt/format.h> +#include <fmt/ranges.h> + +#include "aidge/data/DataType.hpp" // Aidge::isFloatingPoint +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/scheduler/Scheduler.hpp" + +Aidge::DynamicAnalysis::DynamicAnalysis(const Scheduler& scheduler) + : mScheduler(scheduler) +{ + //ctor +} + +Aidge::DynamicAnalysis::~DynamicAnalysis() = default; + +std::size_t Aidge::DynamicAnalysis::getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } +std::size_t Aidge::DynamicAnalysis::getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } +std::size_t Aidge::DynamicAnalysis::getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } +std::size_t Aidge::DynamicAnalysis::getNbNLOps() const { return accumulate(&OperatorStats::getNbNLOps); } +std::size_t Aidge::DynamicAnalysis::getNbOps() const { return accumulate(&OperatorStats::getNbOps); } +std::size_t Aidge::DynamicAnalysis::getNbArithmIntOps() const { return accumulate(&OperatorStats::getNbArithmIntOps); } +std::size_t Aidge::DynamicAnalysis::getNbArithmFpOps() const { return accumulate(&OperatorStats::getNbArithmFpOps); } +std::size_t Aidge::DynamicAnalysis::getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); } + +std::size_t Aidge::DynamicAnalysis::accumulate(std::size_t (OperatorStats::*func)() const) const { + const auto& scheduling = mScheduler.getScheduling(); + return std::accumulate( + scheduling.cbegin(), + scheduling.cend(), + std::size_t(0), + [this, func](const std::size_t& lhs, const Scheduler::SchedulingElement& rhs) { + return lhs + (OperatorStats::getOpStats(rhs.node).get()->*func)(); + }); +} diff --git a/src/analysis/OperatorStats.cpp b/src/analysis/OperatorStats.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a020403ad826ea6827aa120dc9a074ff770487d2 --- /dev/null +++ b/src/analysis/OperatorStats.cpp @@ -0,0 +1,66 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/analysis/StaticAnalysis.hpp" + +#include <cstddef> // std::size_t +#include <memory> +#include <numeric> // std::accumulate +#include <set> + +#include <fmt/core.h> // fmt::println +#include <fmt/format.h> +#include <fmt/ranges.h> + +#include "aidge/data/DataType.hpp" // Aidge::isFloatingPoint +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +Aidge::OperatorStats::OperatorStats(const Operator& op) + : mOp(op) +{ + //ctor +} + +std::shared_ptr<Aidge::OperatorStats> Aidge::OperatorStats::getOpStats(std::shared_ptr<Node> node) { + return (Registrar<OperatorStats>::exists(node->type())) + ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) + : (node->getOperator()->isAtomic()) + ? std::make_shared<OperatorStats>(*(node->getOperator())) + : std::make_shared<MetaOpStats>(*(node->getOperator())); +} + +Aidge::OperatorStats::~OperatorStats() = default; + +std::size_t Aidge::OperatorStats::getNbArithmIntOps() const { + const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); + if (opTensor) { + if (!isFloatingPoint(opTensor->getOutput(0)->dataType())) { + return getNbArithmOps(); + } + } + return 0; +} + +//////////////////////////////////////////////////////////////////////////////// + +Aidge::MetaOpStats::~MetaOpStats() = default; + +std::size_t Aidge::MetaOpStats::getNbArithmOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } +std::size_t Aidge::MetaOpStats::getNbLogicOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } +std::size_t Aidge::MetaOpStats::getNbCompOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } +std::size_t Aidge::MetaOpStats::getNbNLOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); } +std::size_t Aidge::MetaOpStats::getNbArithmIntOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); } +std::size_t Aidge::MetaOpStats::getNbMACOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } diff --git a/src/graph/StaticAnalysis.cpp b/src/analysis/StaticAnalysis.cpp similarity index 78% rename from src/graph/StaticAnalysis.cpp rename to src/analysis/StaticAnalysis.cpp index 418ae893631839f6b13c16df422f832fee4615b7..0e32618c2f0a3c8fc3144bab1e80f9c8bac6cf55 100644 --- a/src/graph/StaticAnalysis.cpp +++ b/src/analysis/StaticAnalysis.cpp @@ -9,7 +9,7 @@ * ********************************************************************************/ -#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/analysis/StaticAnalysis.hpp" #include <cstddef> // std::size_t #include <memory> @@ -27,26 +27,6 @@ #include "aidge/operator/Operator.hpp" #include "aidge/operator/OperatorTensor.hpp" -Aidge::OperatorStats::OperatorStats(const Operator& op) - : mOp(op) -{ - //ctor -} - -Aidge::OperatorStats::~OperatorStats() = default; - -std::size_t Aidge::OperatorStats::getNbArithmIntOps() const { - const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); - if (opTensor) { - if (!isFloatingPoint(opTensor->getOutput(0)->dataType())) { - return getNbArithmOps(); - } - } - return 0; -} - -//////////////////////////////////////////////////////////////////////////////// - Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph) : mGraph(graph) { @@ -174,14 +154,6 @@ std::size_t Aidge::StaticAnalysis::getParamsSize(std::shared_ptr<Node> node) con return paramsSize; } -std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const { - return (Registrar<OperatorStats>::exists(node->type())) - ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) - : (node->getOperator()->isAtomic()) - ? std::make_shared<OperatorStats>(*(node->getOperator())) - : std::make_shared<MetaOpStats>(*(node->getOperator())); -} - std::size_t Aidge::StaticAnalysis::getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } std::size_t Aidge::StaticAnalysis::getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } std::size_t Aidge::StaticAnalysis::getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } @@ -197,17 +169,6 @@ std::size_t Aidge::StaticAnalysis::accumulate(std::size_t (OperatorStats::*func) mGraph->getNodes().cend(), std::size_t(0), [this, func](const std::size_t& lhs, const std::shared_ptr<Node>& rhs) { - return lhs + (this->getOpStats(rhs).get()->*func)(); + return lhs + (OperatorStats::getOpStats(rhs).get()->*func)(); }); } - -//////////////////////////////////////////////////////////////////////////////// - -Aidge::MetaOpStats::~MetaOpStats() = default; - -std::size_t Aidge::MetaOpStats::getNbArithmOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } -std::size_t Aidge::MetaOpStats::getNbLogicOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } -std::size_t Aidge::MetaOpStats::getNbCompOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } -std::size_t Aidge::MetaOpStats::getNbNLOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); } -std::size_t Aidge::MetaOpStats::getNbArithmIntOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); } -std::size_t Aidge::MetaOpStats::getNbMACOps() const { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 99f2de6699090a222ce8c230e586e043c110652c..155a5e7e4689b2e0d645a4288b8a460c0687c395 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -668,7 +668,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemory(bool incProducers, bool wr MemoryManager memManager; for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { - for (const auto& node : getStaticScheduling(step)) { + for (const auto& node : getSequentialStaticScheduling(step)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; @@ -787,7 +787,7 @@ Aidge::MemoryManager Aidge::Scheduler::generateMemoryAutoConcat(bool incProducer for (std::size_t step = 0; step < mStaticSchedule.size(); ++step) { // AsLateAsPossible ensures that when a node child is Concat, all the parents // of the Concat parents have already been memory mapped! - for (const auto& node : getStaticScheduling(step, EarlyLateSort::AsLateAsPossible)) { + for (const auto& node : getSequentialStaticScheduling(step, SchedulingPolicy::AsLateAsPossible)) { if (!incProducers && node->type() == Producer_Op::Type) { memManager.releaseDependencies(node); continue; @@ -1038,7 +1038,7 @@ void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Te } } -void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName, bool ignoreProducers) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -1054,6 +1054,10 @@ void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const const auto globalStart = mScheduling[0].start; for (const auto& element : mScheduling) { + if (ignoreProducers && element.node->type() == "Producer") { + continue; + } + auto name = namePtrTable.at(element.node); // Mermaid does not allow : character in task title std::replace(name.begin(), name.end(), ':', '_'); @@ -1068,7 +1072,7 @@ void Aidge::Scheduler::saveSchedulingDiagram(const std::string& fileName) const fmt::print(fp.get(), "\n"); } -void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) const { +void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName, bool ignoreProducers) const { auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose); if (!fp) { @@ -1084,6 +1088,10 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) for (const auto& schedule : mStaticSchedule) { for (const auto& element : schedule) { + if (ignoreProducers && element->node->type() == "Producer") { + continue; + } + auto name = namePtrTable.at(element->node); // Mermaid does not allow : character in task title std::replace(name.begin(), name.end(), ':', '_'); @@ -1154,17 +1162,17 @@ void Aidge::Scheduler::saveFactorizedStaticSchedulingDiagram(const std::string& fmt::print(fp.get(), "\n"); } -std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step, EarlyLateSort sorting) const { - AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?"); - AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step); +std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getSequentialStaticScheduling(std::size_t step, SchedulingPolicy policy) const { + AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getSequentialStaticScheduling(): static scheduling is empty, did you generate scheduling first?"); + AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getSequentialStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step); std::deque<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(step).begin(), mStaticSchedule.at(step).end()); - if (sorting == EarlyLateSort::AsSoonAsPossible) { + if (policy == SchedulingPolicy::AsSoonAsPossible) { std::stable_sort(staticSchedule.begin(), staticSchedule.end(), [](const auto& lhs, const auto& rhs) { return ((lhs->early < rhs->early) || (lhs->early == rhs->early && lhs->late < rhs->late)); }); } - else if (sorting == EarlyLateSort::AsLateAsPossible) { + else if (policy == SchedulingPolicy::AsLateAsPossible) { // The last condition (lhs->early > rhs->early) ensures that when on a // branch join, one does not switch branch just before the join if there // is only a single node (scheduled as late as possible, since not in the diff --git a/src/scheduler/SequentialScheduler.cpp b/src/scheduler/SequentialScheduler.cpp index 2b1956d790f960124db5c034fa8b4fb790af1d54..07f01ce09888daa1e69e1758d92795cd2b45f124 100644 --- a/src/scheduler/SequentialScheduler.cpp +++ b/src/scheduler/SequentialScheduler.cpp @@ -45,28 +45,18 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, const std::vector<std } // Sort static scheduling according to the policy - std::vector<StaticSchedulingElement*> staticSchedule(mStaticSchedule.at(mStaticScheduleStep).begin(), mStaticSchedule.at(mStaticScheduleStep).end()); - - if (mSchedulingPolicy == SchedulingPolicy::AsSoonAsPossible) { - std::stable_sort(staticSchedule.begin(), staticSchedule.end(), - [](const auto& lhs, const auto& rhs) { return (lhs->early < rhs->early); }); - } - else if (mSchedulingPolicy == SchedulingPolicy::AsLateAsPossible) { - std::stable_sort(staticSchedule.begin(), staticSchedule.end(), - [](const auto& lhs, const auto& rhs) { return (lhs->late < rhs->late); }); - } - + const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy); const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - for (const auto& runnable : staticSchedule) { - const bool skip = !isConditionalNodeRequired(runnable->node); - Log::debug("run: {}{}", namePtrTable.at(runnable->node), (skip) ? " -- skipped" : ""); + for (const auto& runnable : nodes) { + const bool skip = !isConditionalNodeRequired(runnable); + Log::debug("run: {}{}", namePtrTable.at(runnable), (skip) ? " -- skipped" : ""); if (!skip) { const auto tStart = std::chrono::high_resolution_clock::now(); - runnable->node->forward(); + runnable->forward(); const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement(runnable->node, tStart, tEnd)); + mScheduling.push_back(SchedulingElement(runnable, tStart, tEnd)); } } @@ -87,17 +77,20 @@ void Aidge::SequentialScheduler::backward() { } // map of node <-> info to display with verbose + const auto nodes = getSequentialStaticScheduling(mStaticScheduleStep, mSchedulingPolicy); const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})"); // run scheduled operators in reverse order - const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep); - for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) { - Log::debug("run: {}", namePtrTable.at((*runnable)->node)); - - const auto tStart = std::chrono::high_resolution_clock::now(); - (*runnable)->node->backward(); - const auto tEnd = std::chrono::high_resolution_clock::now(); - mScheduling.push_back(SchedulingElement((*runnable)->node, tStart, tEnd)); + for (auto runnable = nodes.crbegin(); runnable != nodes.crend(); ++runnable) { + const bool skip = !isConditionalNodeRequired((*runnable)); + Log::debug("run: {}{}", namePtrTable.at((*runnable)), (skip) ? " -- skipped" : ""); + + if (!skip) { + const auto tStart = std::chrono::high_resolution_clock::now(); + (*runnable)->backward(); + const auto tEnd = std::chrono::high_resolution_clock::now(); + mScheduling.push_back(SchedulingElement((*runnable), tStart, tEnd)); + } } ++mStaticScheduleStep; diff --git a/unit_tests/graph/Test_StaticAnalysis.cpp b/unit_tests/analysis/Test_StaticAnalysis.cpp similarity index 93% rename from unit_tests/graph/Test_StaticAnalysis.cpp rename to unit_tests/analysis/Test_StaticAnalysis.cpp index 9488cbaf60fffcaee32a573993a46a0a440a4dea..a491cb14361b0660e7e9359c6d3aa26e48a46e12 100644 --- a/unit_tests/graph/Test_StaticAnalysis.cpp +++ b/unit_tests/analysis/Test_StaticAnalysis.cpp @@ -16,7 +16,8 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/OpArgs.hpp" -#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/analysis/OperatorStats.hpp" +#include "aidge/analysis/StaticAnalysis.hpp" #include "aidge/operator/Add.hpp" #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/FC.hpp" @@ -53,7 +54,7 @@ TEST_CASE("[core/graph] StaticAnalysis") { REQUIRE(stats.getNbParams(g1->getNode("conv2")) == 4 * 8 * 5 * 5 + 8); REQUIRE(stats.getNbParams(g1->getNode("conv3")) == 8 * 16 * 3 * 3 + 16); - const auto conv1Stats = stats.getOpStats(g1->getNode("conv1")); + const auto conv1Stats = OperatorStats::getOpStats(g1->getNode("conv1")); REQUIRE(conv1Stats->getNbMACOps() == 1LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); REQUIRE(conv1Stats->getNbArithmOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); REQUIRE(conv1Stats->getNbArithmFpOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 042b04f01bdc1430c8b9f1b9df6951f12b821ed1..cf44280551fb04f4636fa72a23b32a51219020da 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -58,7 +58,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") { //op->getOperator()->updateConsummerProducer(); // require implementation //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); - //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); + //REQUIRE(microGraphScheduler->getSequentialStaticScheduling().size() == 2); } SECTION("LSTM") { diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp index ec850d28109a2682bb762c89e814622de6eec3d8..dbe0ef3ae3a6fb253b997064e6a69420846dcaba 100644 --- a/unit_tests/scheduler/Test_Scheduler.cpp +++ b/unit_tests/scheduler/Test_Scheduler.cpp @@ -75,7 +75,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { auto scheduler = SequentialScheduler(g1); scheduler.generateScheduling(); fmt::print("gen scheduling finished\n"); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); @@ -118,7 +118,7 @@ TEST_CASE("randomScheduling", "[Scheduler][randomGen]") { // auto scheduler = SequentialScheduler(g1); // scheduler.generateScheduling(); // fmt::print("gen scheduling finished\n"); - // const auto sch = scheduler.getStaticScheduling(); + // const auto sch = scheduler.getSequentialStaticScheduling(); // const auto namePtrTable = g1->getRankedNodesName("{0} ({1}#{3})"); @@ -146,7 +146,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({data1, identity}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(sch.size() == nodes.size()); REQUIRE(sch[0] == data1); @@ -159,7 +159,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({data1}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(sch.size() == nodes.size()); REQUIRE(sch[0] == data1); @@ -171,7 +171,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({gen1}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(sch.size() == nodes.size()); REQUIRE(sch[0] == gen1); @@ -183,7 +183,7 @@ TEST_CASE("someScheduling", "[Scheduler][someUseCases]") { g->add({dead1}); auto scheduler = SequentialScheduler(g); scheduler.generateScheduling(); - const auto sch = scheduler.getStaticScheduling(); + const auto sch = scheduler.getSequentialStaticScheduling(); const auto nodes = g->getNodes(); REQUIRE(nodes.size() == 1); REQUIRE(sch.size() == 0);