diff --git a/aidge_core/simplify_graph.py b/aidge_core/simplify_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..30ee04e6c62c1dfd465d425d1b430d2b5813383b
--- /dev/null
+++ b/aidge_core/simplify_graph.py
@@ -0,0 +1,56 @@
+import numpy as np
+import aidge_core
+
+def simplify_graph(graph: aidge_core.GraphView):
+    """
+    Simplify a graph loaded from ONNX.
+
+    :param graph: The GraphView to simplify.
+    :type graph: aidge_core.GraphView
+    """
+
+    def check_constant_producer(value):
+        def _check_constant_producer(node):
+            out = node.get_operator().get_output(0)
+            return (len(out) == 1 and np.isclose(out[0], value))
+        return _check_constant_producer
+
+    gm = aidge_core.SinglePassGraphMatching(graph)
+    gm.add_node_lambda("Constant_sqrt2", check_constant_producer(np.sqrt(2)))
+    gm.add_node_lambda("Constant_1", check_constant_producer(1))
+    gm.add_node_lambda("Constant_0_5", check_constant_producer(0.5))
+
+    # Linear [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "MatMul-*>Add", "Linear")
+
+    # LayerNorm [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "ReduceMean-*>Sub#1~>(Pow#1->ReduceMean-*>Add#1->Sqrt)-*>Div#1-*>Mul#1-*>Add#2;"
+                                   "Sub#1~*>Div#1;"
+                                   "Pow#1<1~Producer;"
+                                   "Add#1<*~Producer;"
+                                   "Mul#1<*~Producer;"
+                                   "Add#2<*~Producer;"
+                                   "Sub#1~>$", "LayerNorm")
+
+    # ScaledDotProductAttention [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "MatMul->Div#1->Softmax-*>MatMul;"
+                                   "Div#1<1~Producer", "ScaledDotProductAttention")
+
+    # MultiHeadAttention [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "ScaledDotProductAttention#1->Transpose->Reshape#1->Linear;"
+                                   "Reshape#1<1~Producer;"
+                                   "ScaledDotProductAttention#1<0-(Transpose<-Reshape#2<-Add#1);"
+                                   "ScaledDotProductAttention#1<1-(Transpose<-Reshape#3<-Add#2);"
+                                   "ScaledDotProductAttention#1<2-(Transpose<-Reshape#4<-Add#3);"
+                                   "Reshape#2<1~Producer;"
+                                   "Add#1<*-0-Split#1;"
+                                   "Add#2<*-1-Split#1;"
+                                   "Add#3<*-2-Split#1;"
+                                   "Split#1<-MatMul;"
+                                   "Split#1<1~Producer", "MultiHeadAttention")
+
+    # GeLU [from PyTorch ONNX]
+    aidge_core.fuse_to_metaops(gm, "Div#1->Erf->Add#1-*>Mul->Mul#2;"
+                                   "Div#1<1~Producer[Constant_sqrt2];"
+                                   "Add#1<*~Producer[Constant_1];"
+                                   "Mul#2<*~Producer[Constant_0_5]", "GeLU")
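A minimal usage sketch of the recipe above (assuming the model is imported with the separate aidge_onnx package; the loader name and file path are illustrative, not part of this patch):

    import aidge_core
    import aidge_onnx  # assumption: ONNX import is provided by aidge_onnx
    from aidge_core.simplify_graph import simplify_graph

    model = aidge_onnx.load_onnx("model.onnx")  # placeholder path
    simplify_graph(model)

    # Matched patterns (Linear, LayerNorm, ScaledDotProductAttention,
    # MultiHeadAttention, GeLU) now appear as single meta operator nodes.
    print({node.type() for node in model.get_nodes()})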
title = "log_params_size" + if filename is not None: + self._log_bar(series, filename, title, legend, log_scale) + return series + + def log_nb_arithm_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale) + + def log_nb_logic_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale) + + def log_nb_comp_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale) + + def log_nb_nl_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_nl_ops, filename, title, log_scale) + + def log_nb_mac_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale) + + def log_nb_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_ops, filename, title, log_scale) + + def log_nb_arithm_int_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_int_ops, filename, title, log_scale) + + def log_nb_arithm_fp_ops(self, filename, title=None, log_scale=False): + return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_fp_ops, filename, title, log_scale) + + def log_nb_ops_by_type(self, filename, title=None, log_scale=False): + return self._log_callback([aidge_core.OperatorStats.get_nb_arithm_int_ops, + aidge_core.OperatorStats.get_nb_arithm_fp_ops, + aidge_core.OperatorStats.get_nb_logic_ops, + aidge_core.OperatorStats.get_nb_comp_ops, + aidge_core.OperatorStats.get_nb_nl_ops], filename, title, log_scale) + + def _log_callback(self, callback, filename, title=None, log_scale=False): + """ + Log a statistic given by an OperatorStats callback member function. + Usage: + + stats = StaticAnalysisExt(model) + stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator") + + :param func: OperatorStats member function to call. + :param filename: Output graph file name. + :type filename: str + :param title: Title of the graph. 
+
+    def _log_callback(self, callback, filename, title=None, log_scale=False):
+        """
+        Log a statistic given by an OperatorStats callback member function.
+        Usage:
+
+            stats = StaticAnalysisExt(model)
+            stats._log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator")
+
+        :param callback: OperatorStats member function to call.
+        :param filename: Output graph file name.
+        :type filename: str
+        :param title: Title of the graph.
+        :type title: str
+        """
+
+        namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})")
+        nodes = self.get_graph().get_ordered_nodes()
+        series = []
+        legend = None
+
+        for node in nodes:
+            if node.type() == "Producer":
+                continue
+
+            stats = self.get_op_stats(node)
+            name = namePtrTable[node]
+            attr = {}
+            if type(node.get_operator()) is aidge_core.GenericOperatorOp:
+                # Display Generic Op in orange
+                attr = {'color': 'orange'}
+            elif not node.get_operator().is_atomic():
+                # Display Meta Op in bold
+                attr = {'fontweight': 'bold'}
+            elif node.type() not in aidge_core.get_keys_OperatorStats():
+                # Display unsupported operator in red labels
+                attr = {'color': 'red'}
+            if attr:
+                name = (name, attr)
+            if isinstance(callback, list):
+                series.append([name, [partial(cb, stats)() for cb in callback]])
+                legend = [cb.__name__ for cb in callback]
+                if title is None: title = str(legend)
+            else:
+                series.append([name, partial(callback, stats)()])
+                if title is None: title = callback.__name__
+
+        if title is None: title = str(callback)
+        if filename is not None:
+            self._log_bar(series, filename, title, legend, log_scale)
+        return series
+
+    def _log_bar(self, series, filename, title=None, legend=None, log_scale=False):
+        names, values = zip(*series)
+        names_only = [item[0] if isinstance(item, tuple) else item for item in names]
+        fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5))
+        plt.xlim(-0.5, len(names) - 0.5)
+        if isinstance(values[0], list):
+            series = [list(i) for i in zip(*values)]
+            bot = np.zeros(len(series[0]))
+            for i, serie in enumerate(series):
+                plt.bar(names_only, serie, bottom=bot)
+                bot += serie
+        else:
+            plt.bar(names_only, values)
+        ax.yaxis.minorticks_on()
+        plt.grid(axis='y', which='major', linestyle='--', color='gray')
+        plt.grid(axis='y', which='minor', linestyle=':', color='lightgray')
+        formatter0 = matplotlib.ticker.EngFormatter(unit='')
+        ax.yaxis.set_major_formatter(formatter0)
+        plt.gca().set_axisbelow(True)
+
+        labels = plt.gca().get_xticks()
+        tick_labels = plt.gca().get_xticklabels()
+        for i, label in enumerate(labels):
+            if isinstance(names[i], tuple):
+                if 'color' in names[i][1]:
+                    tick_labels[i].set_color(names[i][1]['color'])
+                elif 'fontweight' in names[i][1]:
+                    tick_labels[i].set_fontweight(names[i][1]['fontweight'])
+
+        plt.xticks(rotation='vertical')
+        if log_scale: plt.yscale('log')
+        if title is not None: plt.title(title)
+        if legend is not None: plt.legend(legend)
+        plt.savefig(filename, bbox_inches='tight')
+
+    def _log_barh(self, series, filename, title=None, legend=None, log_scale=False):
+        names, values = zip(*series)
+        names_only = [item[0] if isinstance(item, tuple) else item for item in names]
+        fig, ax = plt.subplots(figsize=(10, max(5, len(names)/4)))
+        plt.ylim(-0.5, len(names) - 0.5)
+        if isinstance(values[0], list):
+            series = [list(i) for i in zip(*values)]
+            left = np.zeros(len(series[0]))
+            for i, serie in enumerate(series):
+                plt.barh(names_only, serie, left=left)
+                left += serie
+        else:
+            plt.barh(names_only, values)
+        ax.xaxis.minorticks_on()
+        plt.grid(axis='x', which='major', linestyle='--', color='gray')
+        plt.grid(axis='x', which='minor', linestyle=':', color='lightgray')
+        formatter0 = matplotlib.ticker.EngFormatter(unit='')
+        ax.xaxis.set_major_formatter(formatter0)
+        plt.gca().set_axisbelow(True)
+        plt.gca().xaxis.set_label_position('top')
+        plt.gca().xaxis.tick_top()
+
+        labels = plt.gca().get_yticks()
+        tick_labels = plt.gca().get_yticklabels()
+        for i, label in enumerate(labels):
+            if isinstance(names[i], tuple):
+                if 'color' in names[i][1]:
+                    tick_labels[i].set_color(names[i][1]['color'])
+                elif 'fontweight' in names[i][1]:
+                    tick_labels[i].set_fontweight(names[i][1]['fontweight'])
+
+        if log_scale: plt.xscale('log')
+        if title is not None: plt.title(title)
+        if legend is not None: plt.legend(legend)
+        plt.savefig(filename, bbox_inches='tight')
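A short sketch of how these helpers are meant to be driven (model stands for any aidge_core.GraphView whose dims have been forwarded; the output file names are arbitrary):

    from aidge_core.static_analysis import StaticAnalysisExt

    stats = StaticAnalysisExt(model)  # model: placeholder GraphView

    # One bar per node; also returns the plotted [name, value] series.
    series = stats.log_nb_params("nb_params.png")

    # Stacked bars, one segment per operation category; the legend is
    # derived from the callback names.
    stats.log_nb_ops_by_type("ops_by_type.png", log_scale=True)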
diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index 23221e653ba725e4463b06cfabb5483a20756701..6f877194252c7145ea61e1105e0edb0080409d46 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -52,6 +52,9 @@ enum class DataType {
     Any
 };
 
+bool isDataTypeFloatingPoint(const DataType& type);
+size_t getDataTypeBitWidth(const DataType& type);
+
 enum class DataFormat {
     Default,
     NCHW,
diff --git a/include/aidge/graph/Matching.hpp b/include/aidge/graph/Matching.hpp
index 951aa6b29d73d9055cf9f13c8ddc6313cb506879..b846af10b87b4088dab7fee41187ded91bf531d1 100644
--- a/include/aidge/graph/Matching.hpp
+++ b/include/aidge/graph/Matching.hpp
@@ -154,13 +154,13 @@ public:
      */
     std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);
 
-    inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) {
+    inline void addNodeLambda(const std::string& name, std::function<bool(const NodePtr&)> func) {
         mLambda[name] = func;
     }
 
 private:
     std::shared_ptr<GraphView> mGraph;
-    std::map<std::string, bool(*)(const NodePtr&)> mLambda;
+    std::map<std::string, std::function<bool(const NodePtr&)>> mLambda;
 
     /**
      * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d92356b72b8f1408c3084f9afa6f467d2043e620
--- /dev/null
+++ b/include/aidge/graph/StaticAnalysis.hpp
@@ -0,0 +1,571 @@
+
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_GRAPH_STATICANALYSIS_H_
+#define AIDGE_CORE_GRAPH_STATICANALYSIS_H_
+
+#include <memory>
+
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ConvDepthWise.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/MatMul.hpp"
+#include "aidge/operator/MaxPooling.hpp"
+#include "aidge/operator/ReduceMean.hpp"
+#include "aidge/operator/ReduceSum.hpp"
+#include "aidge/operator/Softmax.hpp"
+#include "aidge/operator/MetaOperator.hpp"
+
+namespace Aidge {
+/**
+ * @brief Base class to compute statistics from an Operator.
+ *
+ */
+class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
+public:
+    OperatorStats(const Operator& op);
+    const Operator& getOperator() const noexcept { return mOp; }
+
+    /**
+     * @brief Get the worst case total number of arithmetic operations for the
+     * operator data flow. This includes base arithmetic operations: +, -, / and *.
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     * Example of Operator with only arithmetic operations: Conv.
+     *
+     * @return size_t Number of arithmetic operations.
+     */
+    virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
+
+    /**
+     * @brief Get the worst case total number of logic operations for the
+     * operator data flow. This includes operations like logical shift, or, and...
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     * Example of Operator with only logic operations: BitShift.
+     *
+     * @return size_t Number of logic operations.
+     */
+    virtual size_t getNbLogicOps() const { return 0; };
+
+    /**
+     * @brief Get the worst case total number of comparison operations for the
+     * operator data flow. This includes operations like <, >, =...
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     * Example of Operator with only comparison operations: MaxPool.
+     *
+     * @return size_t Number of comparison operations.
+     */
+    virtual size_t getNbCompOps() const { return 0; };
+
+    /**
+     * @brief Get the worst case total number of non-linear (NL) operations for the
+     * operator data flow. This includes operations like calls to tanh(), erf(), cos()...
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     * Example of Operator with only NL operations: Tanh.
+     * Non-linear operations are necessarily of floating-point type.
+     *
+     * @return size_t Number of non-linear (NL) operations.
+     */
+    virtual size_t getNbNLOps() const { return 0; };
+
+    /**
+     * @brief Get the worst case total number of operations for the operator data flow.
+     * Total number of operations = arithmetic ops + logic ops + comp ops + NL ops.
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     *
+     * @return size_t Number of operations.
+     */
+    size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbNLOps(); };
+
+    /**
+     * @brief Get the worst case total number of INT arithmetic operations for
+     * the operator data flow.
+     * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     *
+     * @return size_t Number of INT arithmetic operations.
+     */
+    virtual size_t getNbArithmIntOps() const;
+
+    /**
+     * @brief Get the worst case total number of FP arithmetic operations for
+     * the operator data flow.
+     * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     *
+     * @return size_t Number of FP arithmetic operations.
+     */
+    size_t getNbArithmFpOps() const { return getNbArithmOps() - getNbArithmIntOps(); };
+
+    /**
+     * @brief Get the worst case total number of MAC operations for the operator
+     * data flow. MAC operations are included in getNbArithmOps(), with 1 MAC
+     * operation counted as 2 arithmetic operations. MAC can be INT or FP.
+     * Control flow operations (loop counters, index computation...) and memory
+     * accesses are not included.
+     * A naive implementation is considered (more operations might be required
+     * for numerical stability in an actual implementation).
+     *
+     * @return size_t Number of MAC operations.
+     */
+    virtual size_t getNbMACOps() const { return 0; };
+    virtual ~OperatorStats() = default;
+
+protected:
+    const Operator &mOp;
+};
+
+/**
+ * @brief Base class to compute statistics from a GraphView
+ *
+ */
+class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
+public:
+    StaticAnalysis(std::shared_ptr<GraphView> graph);
+    const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
+
+    /**
+     * @brief Get the Operator Stats object corresponding to the given node.
+     *
+     * @param node Node
+     * @return std::shared_ptr<OperatorStats> Node's Operator stats
+     */
+    std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const;
+
+    /**
+     * @brief Get the number of parameters associated to a node. This includes
+     * all Producers directly connected to the node's inputs as well as all
+     * internal Producers (in case of a meta operator).
+     *
+     * Note: this function does not check if parameters are shared between
+     * several nodes or not. This means that simply adding parameters count from
+     * several nodes may lead to a higher number of parameters than in reality
+     * if some of them are shared.
+     *
+     * @param node Node
+     * @return size_t Number of parameters
+     */
+    virtual size_t getNbParams(std::shared_ptr<Node> node) const;
+
+    /**
+     * @brief Get the total parameters memory size, in bits, associated to a node.
+     * This includes all Producers directly connected to the node's inputs as
+     * well as all internal Producers (in case of a meta operator).
+     *
+     * Note: this function does not check if parameters are shared between
+     * several nodes or not. This means that simply adding parameters size from
+     * several nodes may lead to a higher parameter size than in reality
+     * if some of them are shared.
+ * + * @param node Node + * @return size_t Total parameters memory, in bits + */ + virtual size_t getParamsSize(std::shared_ptr<Node> node) const; + + size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } + size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } + size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } + size_t getNbNLOps() const { return accumulate(&OperatorStats::getNbNLOps); } + size_t getNbOps() const { return accumulate(&OperatorStats::getNbOps); } + size_t getNbArithmIntOps() const { return accumulate(&OperatorStats::getNbArithmIntOps); } + size_t getNbArithmFpOps() const { return accumulate(&OperatorStats::getNbArithmFpOps); } + size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); } + virtual void summary(bool incProducers = false) const; + virtual ~StaticAnalysis() = default; + +protected: + const std::shared_ptr<GraphView> mGraph; + + size_t accumulate(size_t (OperatorStats::*func)() const) const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class MetaOpStats : public OperatorStats { +public: + MetaOpStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<MetaOpStats> create(const Operator& op) { + return std::make_unique<MetaOpStats>(op); + } + + size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } + size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } + size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } + size_t getNbNLOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); } + size_t getNbArithmIntOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); } + size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } +}; + +template <class OP> +class ConvStats : public OperatorStats { +public: + ConvStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ConvStats<OP>> create(const Operator& op) { + return std::make_unique<ConvStats<OP>>(op); + } + + size_t getNbMACOps() const override { + const OP& op_ = dynamic_cast<const OP&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const std::size_t weightsSize = op_.getInput(1)->size(); + const std::size_t outputSize + = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2, + op_.getOutput(0)->dims().cend(), + 1, + std::multiplies<size_t>()); // NCHW... + const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW + return batchSize * (weightsSize * outputSize); + } +}; + +// Beware: cannot use Conv_Op<2>::Type as key because static variable initialization order is undefined! 
+REGISTRAR(OperatorStats, "Conv1D", ConvStats<Conv_Op<1>>::create); +REGISTRAR(OperatorStats, "ConvDepthWise1D", ConvStats<ConvDepthWise_Op<1>>::create); +REGISTRAR(OperatorStats, "Conv2D", ConvStats<Conv_Op<2>>::create); +REGISTRAR(OperatorStats, "ConvDepthWise2D", ConvStats<ConvDepthWise_Op<2>>::create); + +template <class OP> +class MaxPoolingStats : public OperatorStats { +public: + MaxPoolingStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<MaxPoolingStats<OP>> create(const Operator& op) { + return std::make_unique<MaxPoolingStats<OP>>(op); + } + + size_t getNbCompOps() const override { + const OP& op_ = dynamic_cast<const OP&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const std::size_t poolSize + = std::accumulate(op_.kernelDims().cbegin(), + op_.kernelDims().cend(), + 1, + std::multiplies<size_t>()); + const std::size_t outputSize + = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2, + op_.getOutput(0)->dims().cend(), + 1, + std::multiplies<size_t>()); // NCHW... + const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW + return batchSize * ((poolSize - 1) * outputSize); + } +}; + +REGISTRAR(OperatorStats, "MaxPooling1D", MaxPoolingStats<MaxPooling_Op<1>>::create); +REGISTRAR(OperatorStats, "MaxPooling2D", MaxPoolingStats<MaxPooling_Op<2>>::create); +REGISTRAR(OperatorStats, "MaxPooling3D", MaxPoolingStats<MaxPooling_Op<3>>::create); + +template <class OP> +class AvgPoolingStats : public OperatorStats { +public: + AvgPoolingStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<AvgPoolingStats<OP>> create(const Operator& op) { + return std::make_unique<AvgPoolingStats<OP>>(op); + } + + size_t getNbArithmOps() const override { + const OP& op_ = dynamic_cast<const OP&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const std::size_t poolSize + = std::accumulate(op_.kernelDims().cbegin(), + op_.kernelDims().cend(), + 1, + std::multiplies<size_t>()); + const std::size_t outputSize + = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2, + op_.getOutput(0)->dims().cend(), + 1, + std::multiplies<size_t>()); // NCHW... 
+ const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW + // (poolSize - 1) additions + 1 division for each output + return batchSize * (poolSize * outputSize); + } +}; + +REGISTRAR(OperatorStats, "AvgPooling1D", AvgPoolingStats<AvgPooling_Op<1>>::create); +REGISTRAR(OperatorStats, "AvgPooling2D", AvgPoolingStats<AvgPooling_Op<2>>::create); +REGISTRAR(OperatorStats, "AvgPooling3D", AvgPoolingStats<AvgPooling_Op<3>>::create); +REGISTRAR(OperatorStats, "AvgPooling4D", AvgPoolingStats<AvgPooling_Op<4>>::create); + +class FCStats : public OperatorStats { +public: + FCStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<FCStats> create(const Operator& op) { + return std::make_unique<FCStats>(op); + } + + size_t getNbMACOps() const override { + const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const std::size_t weightsSize = op_.getInput(1)->size(); + const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW + return batchSize * weightsSize; + } +}; + +REGISTRAR(OperatorStats, "FC", FCStats::create); + +class MatMulStats : public OperatorStats { +public: + MatMulStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<MatMulStats> create(const Operator& op) { + return std::make_unique<MatMulStats>(op); + } + + size_t getNbMACOps() const override { + const MatMul_Op& op_ = dynamic_cast<const MatMul_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const size_t n = (op_.getInput(0)->dims().size() > 1) + ? op_.getInput(0)->dims().end()[-2] : 1; + const size_t k = op_.getInput(0)->dims().back(); + const size_t m = (op_.getInput(1)->dims().size() > 1) + ? op_.getInput(1)->dims().back() : 1; + const size_t nb = (op_.getInput(0)->dims().size() > 2) + ? 
std::accumulate(op_.getInput(0)->dims().cbegin(), + op_.getInput(0)->dims().cend() - 2, + 1, + std::multiplies<size_t>()) + : 1; + + return nb * n * m * k; + } +}; + +REGISTRAR(OperatorStats, "MatMul", MatMulStats::create); + +class ReLUStats : public OperatorStats { +public: + ReLUStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ReLUStats> create(const Operator& op) { + return std::make_unique<ReLUStats>(op); + } + + size_t getNbCompOps() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + return op_.getOutput(0)->size(); + } +}; + +REGISTRAR(OperatorStats, "ReLU", ReLUStats::create); + +class AbsStats : public OperatorStats { +public: + AbsStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<AbsStats> create(const Operator& op) { + return std::make_unique<AbsStats>(op); + } + + size_t getNbCompOps() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + return op_.getOutput(0)->size(); + } + + // This is in the worst case (all values are negative) + size_t getNbArithmOps() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + return op_.getOutput(0)->size(); + } +}; + +REGISTRAR(OperatorStats, "Abs", AbsStats::create); + +class ReduceMeanStats : public OperatorStats { +public: + ReduceMeanStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ReduceMeanStats> create(const Operator& op) { + return std::make_unique<ReduceMeanStats>(op); + } + + size_t getNbArithmOps() const override { + const ReduceMean_Op& op_ = dynamic_cast<const ReduceMean_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const size_t nbIn = op_.getInput(0)->size(); + const size_t nbOut = op_.getOutput(0)->size(); + const size_t nbReduce = nbIn / nbOut; + // (nbReduce - 1) additions + 1 division for each output + return nbOut * nbReduce; + } +}; + +REGISTRAR(OperatorStats, "ReduceMean", ReduceMeanStats::create); + +class ReduceSumStats : public OperatorStats { +public: + ReduceSumStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ReduceSumStats> create(const Operator& op) { + return std::make_unique<ReduceSumStats>(op); + } + + size_t getNbArithmOps() const override { + const ReduceSum_Op& op_ = dynamic_cast<const ReduceSum_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const size_t nbIn = op_.getInput(0)->size(); + const size_t nbOut = op_.getOutput(0)->size(); + const size_t nbReduce = nbIn / nbOut; + // (nbReduce - 1) additions for each output + return nbOut * (nbReduce - 1); + } +}; + +REGISTRAR(OperatorStats, "ReduceSum", ReduceSumStats::create); + +class SoftmaxStats : public OperatorStats { +public: + SoftmaxStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<SoftmaxStats> create(const Operator& op) { + return std::make_unique<SoftmaxStats>(op); + } + + size_t getNbArithmOps() const override { + const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const size_t axis = (op_.axis() >= 0) ? 
op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
+        const size_t nbReduce = op_.getInput(0)->dims()[axis];
+        const size_t nbOut = op_.getOutput(0)->size();
+        // nbOut divisions + (nbReduce - 1) additions
+        return nbOut + (nbReduce - 1);
+    }
+
+    size_t getNbNLOps() const override {
+        const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
+        AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+        const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
+        const size_t nbReduce = op_.getInput(0)->dims()[axis];
+        const size_t nbOut = op_.getOutput(0)->size();
+        // nbOut exp + nbReduce exp
+        return nbOut + nbReduce;
+    }
+};
+
+REGISTRAR(OperatorStats, "Softmax", SoftmaxStats::create);
+
+class MemOpStats : public OperatorStats {
+public:
+    MemOpStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<MemOpStats> create(const Operator& op) {
+        return std::make_unique<MemOpStats>(op);
+    }
+};
+
+REGISTRAR(OperatorStats, "Reshape", MemOpStats::create);
+REGISTRAR(OperatorStats, "Transpose", MemOpStats::create);
+REGISTRAR(OperatorStats, "Concat", MemOpStats::create);
+REGISTRAR(OperatorStats, "Split", MemOpStats::create);
+REGISTRAR(OperatorStats, "Slice", MemOpStats::create);
+REGISTRAR(OperatorStats, "Squeeze", MemOpStats::create);
+REGISTRAR(OperatorStats, "Unsqueeze", MemOpStats::create);
+REGISTRAR(OperatorStats, "Gather", MemOpStats::create);
+REGISTRAR(OperatorStats, "Identity", MemOpStats::create);
+
+class ElemWiseOpStats : public OperatorStats {
+public:
+    ElemWiseOpStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<ElemWiseOpStats> create(const Operator& op) {
+        return std::make_unique<ElemWiseOpStats>(op);
+    }
+
+    size_t getNbArithmOps() const override {
+        const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+        AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+        return op_.getOutput(0)->size();
+    }
+};
+
+REGISTRAR(OperatorStats, "Add", ElemWiseOpStats::create);
+REGISTRAR(OperatorStats, "Sub", ElemWiseOpStats::create);
+REGISTRAR(OperatorStats, "Mul", ElemWiseOpStats::create);
+REGISTRAR(OperatorStats, "Div", ElemWiseOpStats::create);
+
+class ElemWiseLogicOpStats : public OperatorStats {
+public:
+    ElemWiseLogicOpStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<ElemWiseLogicOpStats> create(const Operator& op) {
+        return std::make_unique<ElemWiseLogicOpStats>(op);
+    }
+
+    size_t getNbLogicOps() const override {
+        const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+        AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+        return op_.getOutput(0)->size();
+    }
+};
+
+REGISTRAR(OperatorStats, "And", ElemWiseLogicOpStats::create);
+
+class ElemWiseNLOpStats : public OperatorStats {
+public:
+    ElemWiseNLOpStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<ElemWiseNLOpStats> create(const Operator& op) {
+        return std::make_unique<ElemWiseNLOpStats>(op);
+    }
+
+    size_t getNbNLOps() const override {
+        const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+        AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+        return op_.getOutput(0)->size();
+    }
+};
+
+REGISTRAR(OperatorStats, "Atan", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Sqrt", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Erf", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Ln", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Sigmoid", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Tanh", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Pow", ElemWiseNLOpStats::create);
+}
+
+#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
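Because OperatorStats is registrable and exposed to Python (see pybind_StaticAnalysis.cpp below), per-type statistics can in principle also be provided from Python for custom operators. A sketch, assuming the register_OperatorStats helper generated by declare_registrable and a custom operator type named "MyClamp" (both hypothetical):

    import aidge_core

    class MyClampStats(aidge_core.OperatorStats):
        # The trampoline dispatches on the C++ method name, so the override
        # is spelled getNbCompOps rather than get_nb_comp_ops.
        def getNbCompOps(self):
            # worst case: one comparison per output element
            return len(self.get_operator().get_output(0))

    # assumed registration helper name; the key must match node.type()
    aidge_core.register_OperatorStats("MyClamp", MyClampStats)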
+REGISTRAR(OperatorStats, "Sigmoid", ElemWiseNLOpStats::create); +REGISTRAR(OperatorStats, "Tanh", ElemWiseNLOpStats::create); +REGISTRAR(OperatorStats, "Pow", ElemWiseNLOpStats::create); +} + +#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */ diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 82ecc7d28b723d2b3e268f4fb6fbf20d595240ff..86c722b158657633d4509c1181b1f18201d0d514 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -17,7 +17,7 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/graphRegex/matchFsm/MatchResult.hpp" +#include "aidge/graph/Matching.hpp" namespace Aidge { @@ -81,9 +81,6 @@ size_t removeIdentity(std::shared_ptr<GraphView> graph); */ void removeFlatten(std::shared_ptr<Node> flatten); - -void removeFlatten(std::shared_ptr<MatchSolution> solution); - /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * @@ -151,6 +148,15 @@ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); */ void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims); +/** + * Fuse each sub-graph matching a query in a Meta Operator. + * @param gm SinglePassGraphMatching containing the graph to manipulate + * @param query Sub-graph matching query + * @param type Type name of the resulting meta operators + * @return size_t Number of replacement +*/ +size_t fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type = ""); + /** * Fuse each sub-graph matching a query in a Meta Operator. * @param graph Graph to manipulate diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 52056852bc454f65f7d12cfc0608e5b6b0b1d933..1b55d7afbf8263a77cf70752fc92f72ef5027904 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -166,7 +166,14 @@ public: attributes[elt.first.c_str()] = future_std::any_cast<const DynamicAttributes&>(elt.second).dict(); } else { - attributes[elt.first.c_str()] = mAnyUtils.at(elt.second.type())->cast(elt.second); + // At this point, not every attribute may be known to mAnyUtils + const auto anyUtilsIt = mAnyUtils.find(elt.second.type()); + if (anyUtilsIt != mAnyUtils.end()) { + attributes[elt.first.c_str()] = anyUtilsIt->second->cast(elt.second); + } + else { + attributes[elt.first.c_str()] = "???"; + } } } return attributes; diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index febb6f2ed594174a7aeef60f26b8f9a5ee0e23e3..4b9d2ad545c47971b7c0dff029585bb4c9ae5638 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -149,8 +149,8 @@ void init_GraphView(py::module& m) { // } // }) .def("get_ranked_nodes", &GraphView::getRankedNodes) + .def("get_ranked_nodes_name", &GraphView::getRankedNodesName, py::arg("format"), py::arg("mark_non_unicity") = true) .def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat")) - ; m.def("get_connected_graph_view", &getConnectedGraphView); diff --git a/python_binding/graph/pybind_Matching.cpp b/python_binding/graph/pybind_Matching.cpp index 94f2471c3f234c1e401484c099a3815dd26d3c30..af385798175afe68d2e89cd65f3645fc7b806eb3 100644 --- a/python_binding/graph/pybind_Matching.cpp +++ b/python_binding/graph/pybind_Matching.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <pybind11/pybind11.h> 
+#include <pybind11/functional.h>
 #include <pybind11/stl.h>
 #include <memory>
 #include <string>
@@ -31,21 +32,20 @@ void init_SinglePassGraphMatching(py::module& m) {
     py::class_<Aidge::SinglePassGraphMatching>(m, "SinglePassGraphMatching")
         .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
         .def("match",
-            [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){
-                // Note: Need to convert set to vector has MatchingResult is not hashable and
-                // set<MatchingResult> cannot be binded
-                std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint);
-                std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end());
-                return vec_res;
-            },
-            py::arg("query"), py::arg("disjoint") = false,
-            R"mydelimiter( Matches a query by direct, single-pass parse and match.
-            :param query: The query string to search.
-            :param disjoint: If true, only keep the longest disjoint matches.
-            :return: A set of MatchingResult instances.
-            )mydelimiter");
-
-
+            [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){
+                // Note: need to convert the set to a vector, as MatchingResult is not hashable
+                // and set<MatchingResult> cannot be bound
+                std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint);
+                std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end());
+                return vec_res;
+            },
+            py::arg("query"), py::arg("disjoint") = false,
+            R"mydelimiter( Matches a query by direct, single-pass parse and match.
+            :param query: The query string to search.
+            :param disjoint: If true, only keep the longest disjoint matches.
+            :return: A set of MatchingResult instances.
+            )mydelimiter")
+        .def("add_node_lambda", &SinglePassGraphMatching::addNodeLambda, py::arg("name"), py::arg("func"));
 }
 } // namespace Aidge
diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b7c704d722e81b36e8d4988a4503428918e16a5a
--- /dev/null
+++ b/python_binding/graph/pybind_StaticAnalysis.cpp
@@ -0,0 +1,141 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/graph/StaticAnalysis.hpp" + +namespace py = pybind11; +namespace Aidge { + +/** + * @brief Trampoline class for binding + * + */ +class pyOperatorStats: public OperatorStats { +public: + using OperatorStats::OperatorStats; // Inherit constructors + + size_t getNbArithmOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbArithmOps + ); + } + + size_t getNbLogicOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbLogicOps + ); + } + + size_t getNbCompOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbCompOps + ); + } + + size_t getNbNLOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbNLOps + ); + } + + size_t getNbArithmIntOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbArithmIntOps + ); + } + + size_t getNbMACOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbMACOps + ); + } +}; + +class pyStaticAnalysis: public StaticAnalysis { +public: + using StaticAnalysis::StaticAnalysis; // Inherit constructors + + size_t getNbParams(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getNbParams, + node + ); + } + + size_t getParamsSize(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getParamsSize, + node + ); + } + + void summary(bool incProducers) const override { + PYBIND11_OVERRIDE( + void, + StaticAnalysis, + summary, + incProducers + ); + } +}; + +void init_StaticAnalysis(py::module& m){ + py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr()) + .def(py::init<const Operator&>(), py::arg("op")) + .def("get_operator", &OperatorStats::getOperator) + .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps) + .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps) + .def("get_nb_comp_ops", &OperatorStats::getNbCompOps) + .def("get_nb_nl_ops", &OperatorStats::getNbNLOps) + .def("get_nb_ops", &OperatorStats::getNbOps) + .def("get_nb_arithm_int_ops", &OperatorStats::getNbArithmIntOps) + .def("get_nb_arithm_fp_ops", &OperatorStats::getNbArithmFpOps) + .def("get_nb_mac_ops", &OperatorStats::getNbMACOps) + ; + declare_registrable<OperatorStats>(m, "OperatorStats"); + + py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr()) + .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) + .def("get_graph", &StaticAnalysis::getGraph) + .def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node")) + .def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node")) + .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps) + .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps) + .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps) + .def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps) + .def("get_nb_ops", &StaticAnalysis::getNbOps) + .def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps) + .def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps) + .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps) + .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) + .def("get_op_stats", &StaticAnalysis::getOpStats, 
py::arg("node"))
+    ;
+}
+}
diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp
index e22f88687eff6856ce57fab6621781ffc86873b4..a1d1889c9a1881d3aa7b6eb9ccb4c23c5314cc80 100644
--- a/python_binding/operator/pybind_Operator.cpp
+++ b/python_binding/operator/pybind_Operator.cpp
@@ -63,6 +63,7 @@ void init_Operator(py::module& m){
     .def_property_readonly("attr", &Operator::attributes)
     .def("set_back_edges", &Operator::setBackEdges, py::arg("input_indexes"))
     .def("is_back_edge", &Operator::isBackEdge, py::arg("input_index"))
+    .def("is_atomic", &Operator::isAtomic)
     ;
 }
 }
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index 02f4b732c39ef5cbc1755eb314d32c25c96d01fd..c287314f25c90e6bf8962e31637c9c2990d4db9b 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -27,6 +27,7 @@ void init_OperatorImpl(py::module&);
 void init_Log(py::module&);
 void init_Operator(py::module&);
 void init_OperatorTensor(py::module&);
+void init_StaticAnalysis(py::module&);
 
 void init_Add(py::module&);
 void init_And(py::module&);
@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) {
     init_Log(m);
     init_Operator(m);
     init_OperatorTensor(m);
+    init_StaticAnalysis(m);
 
     init_Add(m);
     init_And(m);
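Taken together with the is_atomic and init_StaticAnalysis hooks above, the new Python surface can be exercised roughly as follows (model and the node name "conv1" are placeholders; dims must have been forwarded first):

    import aidge_core

    stats = aidge_core.StaticAnalysis(model)
    stats.summary()

    node = model.get_node("conv1")              # hypothetical node name
    op_stats = stats.get_op_stats(node)
    print(node.type(), op_stats.get_nb_mac_ops())
    print(node.get_operator().is_atomic())      # newly bound Operator::isAtomic()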
diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp
index 6908cbd912b506a7adb7f33a02416d0173174969..77f20b9d655c6d9f6e95b23c4884bd1bc4f9ffd6 100644
--- a/python_binding/recipes/pybind_Recipes.cpp
+++ b/python_binding/recipes/pybind_Recipes.cpp
@@ -112,7 +112,20 @@ void init_Recipes(py::module &m)
       :type recursive: bool
       )mydelimiter");
 
-  m.def("fuse_to_metaops", fuseToMetaOps, py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
+  m.def("fuse_to_metaops", py::overload_cast<SinglePassGraphMatching&, const std::string&, const std::string&>(fuseToMetaOps), py::arg("gm"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
+    Fuse each sub-graph matching a query in a Meta Operator.
+
+    :param gm: SinglePassGraphMatching containing the graph to manipulate
+    :type gm: :py:class:`aidge_core.SinglePassGraphMatching`
+    :param query: Sub-graph matching query
+    :type query: str
+    :param type: Type name of the resulting meta operators
+    :type type: str, optional
+    :return: Number of sub-graphs actually fused in a Meta Operator.
+    :rtype: int
+    )mydelimiter");
+
+  m.def("fuse_to_metaops", py::overload_cast<std::shared_ptr<GraphView>, const std::string&, const std::string&>(fuseToMetaOps), py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
     Fuse each sub-graph matching a query in a Meta Operator.
 
     :param graph_view: Graph view on which we want to apply the recipe
diff --git a/src/data/Data.cpp b/src/data/Data.cpp
index 62a883d08a401e02c86408214a061f893ffbfb4a..865b4ff2dcff68753fb5a4cc9cafccdd129e3c8a 100644
--- a/src/data/Data.cpp
+++ b/src/data/Data.cpp
@@ -11,6 +11,50 @@
 
 #include "aidge/data/Data.hpp"
 
+bool Aidge::isDataTypeFloatingPoint(const DataType& type) {
+    switch (type) {
+    case DataType::Float64:
+    case DataType::Float32:
+    case DataType::Float16:
+    case DataType::BFloat16: return true;
+    default: return false;
+    }
+    return false;
+}
+
+size_t Aidge::getDataTypeBitWidth(const DataType& type) {
+    switch (type) {
+    case DataType::Float64: return 64;
+    case DataType::Float32: return 32;
+    case DataType::Float16: return 16;
+    case DataType::BFloat16: return 16;
+    case DataType::Binary: return 1;
+    case DataType::Ternary: return 2;
+    case DataType::Int2: return 2;
+    case DataType::Int3: return 3;
+    case DataType::Int4: return 4;
+    case DataType::Int5: return 5;
+    case DataType::Int6: return 6;
+    case DataType::Int7: return 7;
+    case DataType::Int8: return 8;
+    case DataType::Int16: return 16;
+    case DataType::Int32: return 32;
+    case DataType::Int64: return 64;
+    case DataType::UInt2: return 2;
+    case DataType::UInt3: return 3;
+    case DataType::UInt4: return 4;
+    case DataType::UInt5: return 5;
+    case DataType::UInt6: return 6;
+    case DataType::UInt7: return 7;
+    case DataType::UInt8: return 8;
+    case DataType::UInt16: return 16;
+    case DataType::UInt32: return 32;
+    case DataType::UInt64: return 64;
+    default: return 0;
+    }
+    return 0;
+}
+
 Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) {
     // Permutation array from default format to src format
     const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)];
diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..033e51022842983caacba9385248c9f02c1e5568
--- /dev/null
+++ b/src/graph/StaticAnalysis.cpp
@@ -0,0 +1,171 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/graph/StaticAnalysis.hpp" + +Aidge::OperatorStats::OperatorStats(const Operator& op) + : mOp(op) +{ + //ctor +} + +size_t Aidge::OperatorStats::getNbArithmIntOps() const { + const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); + if (opTensor) { + if (!isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) { + return getNbArithmOps(); + } + } + return 0; +} + +Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph) + : mGraph(graph) +{ + //ctor +} + +void Aidge::StaticAnalysis::summary(bool incProducers) const { + fmt::println("--------------------------------------------------------------------------------"); + fmt::println(" Layer (type) Output Shape Param #"); + fmt::println("================================================================================"); + + size_t nbParams = 0; + size_t paramsSize = 0; // Size in bits + size_t fwdBwdSize = 0; // Size in bits + + const auto namePtrTable = mGraph->getRankedNodesName("{0} ({1}#{3})"); + for (const auto node : mGraph->getOrderedNodes()) { + if (node->type() == Producer_Op::Type && !incProducers) { + continue; + } + + auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); + std::string outputDimsStr = fmt::format("{: >27}", "?"); + if (opTensor) { + const auto outputDims = opTensor->getOutput(0)->dims(); + outputDimsStr = fmt::format("{: >27}", fmt::format("{}", outputDims)); + + for (size_t out = 0; out < node->nbOutputs(); ++out) { + const auto output = opTensor->getOutput(out); + if (output && node->type() != Producer_Op::Type) { + fwdBwdSize += output->size() + * getDataTypeBitWidth(output->dataType()); + } + } + } + + nbParams += getNbParams(node); + paramsSize += getParamsSize(node); + fmt::println("{: >36}{}{: >16}", + namePtrTable.at(node), outputDimsStr, getNbParams(node)); + } + + size_t inputSize = 0; // Size in bits + for (const auto input : mGraph->getOrderedInputs()) { + if (input.first) { + auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator()); + if (opTensor && opTensor->getInput(input.second)) { + inputSize += opTensor->getInput(input.second)->size() + * getDataTypeBitWidth(opTensor->getInput(input.second)->dataType()); + } + } + } + + fmt::println("================================================================================"); + fmt::println("Total params: {}", nbParams); + fmt::println("--------------------------------------------------------------------------------"); + fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024); + fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024); + fmt::println("Params size (MB): {}", paramsSize / 8.0 / 1024 / 1024); + fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024); + fmt::println("--------------------------------------------------------------------------------"); +} + +size_t Aidge::StaticAnalysis::getNbParams(std::shared_ptr<Node> node) const { + const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); + + size_t nbParams = 0; + + // Look for Producers directly attached to the node's inputs. 
+ size_t i = 0; + for (auto parent : node->inputs()) { + if (parent.first && mGraph->inView(parent.first)) { + if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) { + nbParams += opTensor->getInput(i)->size(); + } + } + ++i; + } + + // Look for internal Producers, in case of meta-op. + if (!node->getOperator()->isAtomic()) { + const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph(); + for (const auto internalNode : microGraph->getNodes()) { + if (internalNode->type() == Producer_Op::Type) { + const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator()); + nbParams += internalOpTensor->getOutput(0)->size(); + } + } + } + + return nbParams; +} + +size_t Aidge::StaticAnalysis::getParamsSize(std::shared_ptr<Node> node) const { + const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); + + size_t paramsSize = 0; + + // Look for Producers directly attached to the node's inputs. + size_t i = 0; + for (auto parent : node->inputs()) { + if (parent.first && mGraph->inView(parent.first)) { + if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) { + paramsSize += opTensor->getInput(i)->size() + * getDataTypeBitWidth(opTensor->getInput(i)->dataType()); + } + } + ++i; + } + + // Look for internal Producers, in case of meta-op. + if (!node->getOperator()->isAtomic()) { + const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph(); + for (const auto internalNode : microGraph->getNodes()) { + if (internalNode->type() == Producer_Op::Type) { + const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator()); + paramsSize += internalOpTensor->getOutput(0)->size() + * getDataTypeBitWidth(internalOpTensor->getOutput(0)->dataType()); + } + } + } + + return paramsSize; +} + +std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const { + return (Registrar<OperatorStats>::exists(node->type())) + ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) + : (node->getOperator()->isAtomic()) + ? std::make_shared<OperatorStats>(*(node->getOperator())) + : std::make_shared<MetaOpStats>(*(node->getOperator())); +} + +size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const { + return std::accumulate( + mGraph->getNodes().cbegin(), + mGraph->getNodes().cend(), + std::size_t(0), + [this, func](const size_t& lhs, const std::shared_ptr<Node>& rhs) { + return lhs + (this->getOpStats(rhs).get()->*func)(); + }); +} diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp index 0ad5e5a1da0e6aef74f7e47751dd2d4e8648980b..ac6536d7e7ec89c5eb1b9efb0c301cfa979739cf 100644 --- a/src/recipes/FuseToMetaOps.cpp +++ b/src/recipes/FuseToMetaOps.cpp @@ -17,9 +17,9 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/recipes/Recipes.hpp" -size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { +size_t Aidge::fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type) { const auto metaType = (!type.empty()) ? 
type : query; - const auto matches = SinglePassGraphMatching(graphView).match(query); + const auto matches = gm.match(query); size_t nbReplaced = 0; for (const auto& match : matches) { @@ -48,3 +48,8 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size()); return nbReplaced; } + +size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { + SinglePassGraphMatching gm(graphView); + return fuseToMetaOps(gm, query, type); +} diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index d6d98d4701cba900548d127879c9b3940cf1d739..8c5fa222a68a7f2eed329be7c49ca62d0d7ba52f 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -50,9 +50,9 @@ TEST_CASE("[core/graph] Matching") { ReLU("relu2"), PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}), ReLU("relu3"), - PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), + PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), Add("add"), - PaddedConv(8, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), + PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), ReLU("relu5"), Add("add2") }); diff --git a/unit_tests/graph/Test_StaticAnalysis.cpp b/unit_tests/graph/Test_StaticAnalysis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9488cbaf60fffcaee32a573993a46a0a440a4dea --- /dev/null +++ b/unit_tests/graph/Test_StaticAnalysis.cpp @@ -0,0 +1,68 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include <fmt/chrono.h> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/BatchNorm.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/operator/Producer.hpp" + +using namespace Aidge; + +TEST_CASE("[core/graph] StaticAnalysis") { + SECTION("Conv") { + auto g1 = Sequential({ + Conv(3, 4, {5, 5}, "conv1"), + ReLU("relu1"), + PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}), + ReLU("relu2"), + PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}), + ReLU("relu3"), + PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), + Add("add"), + PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), + ReLU("relu5"), + Add("add2") + }); + + g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1); + g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1); + g1->updateInputsOutputs(); + + g1->forwardDims({{16, 3, 512, 512}}); + + StaticAnalysis stats(g1); + REQUIRE(stats.getNbParams(g1->getNode("conv1")) == 3 * 4 * 5 * 5 + 4); + REQUIRE(stats.getNbParams(g1->getNode("conv2")) == 4 * 8 * 5 * 5 + 8); + REQUIRE(stats.getNbParams(g1->getNode("conv3")) == 8 * 16 * 3 * 3 + 16); + + const auto conv1Stats = stats.getOpStats(g1->getNode("conv1")); + REQUIRE(conv1Stats->getNbMACOps() == 1LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmFpOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmIntOps() == 0); + + g1->getNode("conv1")->getOperator()->setDataType(DataType::Int8); + REQUIRE(conv1Stats->getNbMACOps() == 1LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmFpOps() == 0); + REQUIRE(conv1Stats->getNbArithmIntOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + } +}
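For reference, the new unit test translates almost line for line to the Python API (a sketch; constructor and helper names follow the existing aidge_core bindings and may differ slightly in your version):

    import aidge_core
    from aidge_core.static_analysis import StaticAnalysisExt

    model = aidge_core.sequential([
        aidge_core.Conv2D(3, 4, [5, 5], name="conv1"),
        aidge_core.ReLU(name="relu1"),
    ])
    model.forward_dims(dims=[[16, 3, 512, 512]])

    stats = StaticAnalysisExt(model)
    stats.summary()

    conv1_stats = stats.get_op_stats(model.get_node("conv1"))
    # 16 batches x 508 x 508 output positions x (5*5*3) weights x 4 filters
    assert conv1_stats.get_nb_mac_ops() == 16 * 508 * 508 * (5 * 5 * 3 * 4)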