Skip to content
Snippets Groups Projects
Commit 5991c8d6 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'static_analysis' into 'dev'

Initial version of hybrid C++/Python static analysis

See merge request !219
parents 51618b23 07f01dec
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!219Initial version of hybrid C++/Python static analysis
Pipeline #59471 failed
Showing
with 1311 additions and 28 deletions
import numpy as np
import aidge_core
def simplify_graph(graph: aidge_core.GraphView):
"""
Simplify a graph loaded from ONNX.
:param graph: The GraphView to simplify.
:type graph: aidge_core.GraphView
"""
def check_constant_producer(value):
def _check_constant_producer(node):
out = node.get_operator().get_output(0)
return (len(out) == 1 and np.isclose(out[0], value))
return _check_constant_producer
gm = aidge_core.SinglePassGraphMatching(graph)
gm.add_node_lambda("Constant_sqrt2", check_constant_producer(np.sqrt(2)))
gm.add_node_lambda("Constant_1", check_constant_producer(1))
gm.add_node_lambda("Constant_0_5", check_constant_producer(0.5))
# Linear [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "MatMul-*>Add", "Linear")
# LayerNorm [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "ReduceMean-*>Sub#1~>(Pow#1->ReduceMean-*>Add#1->Sqrt)-*>Div#1-*>Mul#1-*>Add#2;"
"Sub#1~*>Div#1;"
"Pow#1<1~Producer;"
"Add#1<*~Producer;"
"Mul#1<*~Producer;"
"Add#2<*~Producer;"
"Sub#1~>$", "LayerNorm")
# ScaledDotProductAttention [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "MatMul->Div#1->Softmax-*>MatMul;"
"Div#1<1~Producer", "ScaledDotProductAttention")
# MultiHeadAttention [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "ScaledDotProductAttention#1->Transpose->Reshape#1->Linear;"
"Reshape#1<1~Producer;"
"ScaledDotProductAttention#1<0-(Transpose<-Reshape#2<-Add#1);"
"ScaledDotProductAttention#1<1-(Transpose<-Reshape#3<-Add#2);"
"ScaledDotProductAttention#1<2-(Transpose<-Reshape#4<-Add#3);"
"Reshape#2<1~Producer;"
"Add#1<*-0-Split#1;"
"Add#2<*-1-Split#1;"
"Add#3<*-2-Split#1;"
"Split#1<-MatMul;"
"Split#1<1~Producer", "MultiHeadAttention")
# GeLU [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "Div#1->Erf->Add#1-*>Mul->Mul#2;"
"Div#1<1~Producer[Constant_sqrt2];"
"Add#1<*~Producer[Constant_1];"
"Mul#2<*~Producer[Constant_0_5]", "GeLU")
import matplotlib
import matplotlib.pyplot as plt
from functools import partial
import numpy as np
import aidge_core
class StaticAnalysisExt(aidge_core.StaticAnalysis):
def log_nb_params(self, filename, title=None, log_scale=False):
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
series = []
legend = None
for node in nodes:
if node.type() == "Producer":
continue
name = namePtrTable[node]
series.append([name, self.get_nb_params(node)])
if title is None: title = "log_nb_params"
if filename is not None:
self._log_bar(series, filename, title, legend, log_scale)
return series
def log_params_size(self, filename, title=None, log_scale=False):
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
series = []
legend = None
for node in nodes:
if node.type() == "Producer":
continue
name = namePtrTable[node]
series.append([name, self.log_params_size(node)])
if title is None: title = "log_params_size"
if filename is not None:
self._log_bar(series, filename, title, legend, log_scale)
return series
def log_nb_arithm_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale)
def log_nb_logic_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale)
def log_nb_comp_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale)
def log_nb_nl_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_nl_ops, filename, title, log_scale)
def log_nb_mac_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)
def log_nb_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_ops, filename, title, log_scale)
def log_nb_arithm_int_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_int_ops, filename, title, log_scale)
def log_nb_arithm_fp_ops(self, filename, title=None, log_scale=False):
return self._log_callback(aidge_core.OperatorStats.get_nb_arithm_fp_ops, filename, title, log_scale)
def log_nb_ops_by_type(self, filename, title=None, log_scale=False):
return self._log_callback([aidge_core.OperatorStats.get_nb_arithm_int_ops,
aidge_core.OperatorStats.get_nb_arithm_fp_ops,
aidge_core.OperatorStats.get_nb_logic_ops,
aidge_core.OperatorStats.get_nb_comp_ops,
aidge_core.OperatorStats.get_nb_nl_ops], filename, title, log_scale)
def _log_callback(self, callback, filename, title=None, log_scale=False):
"""
Log a statistic given by an OperatorStats callback member function.
Usage:
stats = StaticAnalysisExt(model)
stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator")
:param func: OperatorStats member function to call.
:param filename: Output graph file name.
:type filename: str
:param title: Title of the graph.
:type title: str
"""
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
series = []
legend = None
for node in nodes:
if node.type() == "Producer":
continue
stats = self.get_op_stats(node)
name = namePtrTable[node]
attr = {}
if type(node.get_operator()) is aidge_core.GenericOperatorOp:
# Display Generic Op in orange
attr = {'color': 'orange'}
elif not node.get_operator().is_atomic():
# Display Meta Op in bold
attr = {'fontweight': 'bold'}
elif node.type() not in aidge_core.get_keys_OperatorStats():
# Display unsupported operator in red labels
attr = {'color': 'red'}
if attr:
name = (name, attr)
if isinstance(callback, list):
series.append([name, [partial(cb, stats)() for cb in callback]])
legend = [cb.__name__ for cb in callback]
if title is None: title = str(legend)
else:
series.append([name, partial(callback, stats)()])
if title is None: title = callback.__name__
if title is None: title = str(callback)
if filename is not None:
self._log_bar(series, filename, title, legend, log_scale)
return series
def _log_bar(self, series, filename, title=None, legend=None, log_scale=False):
names, values = zip(*series)
names_only = [item[0] if isinstance(item, tuple) else item for item in names]
fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5))
plt.xlim(-0.5, len(names) - 0.5)
if isinstance(values[0], list):
series = [list(i) for i in zip(*values)]
bot = np.zeros(len(series[0]))
for i, serie in enumerate(series):
plt.bar(names_only, serie, bottom=bot)
bot += serie
else:
plt.bar(names_only, values)
ax.yaxis.minorticks_on()
plt.grid(axis='y', which='major', linestyle='--', color='gray')
plt.grid(axis='y', which='minor', linestyle=':', color='lightgray')
formatter0 = matplotlib.ticker.EngFormatter(unit='')
ax.yaxis.set_major_formatter(formatter0)
plt.gca().set_axisbelow(True)
labels = plt.gca().get_xticks()
tick_labels = plt.gca().get_xticklabels()
for i, label in enumerate(labels):
if isinstance(names[i], tuple):
if 'color' in names[i][1]:
tick_labels[i].set_color(names[i][1]['color'])
elif 'fontweight' in names[i][1]:
tick_labels[i].set_fontweight(names[i][1]['fontweight'])
plt.xticks(rotation='vertical')
if log_scale: plt.yscale('log')
if title is not None: plt.title(title)
if legend is not None: plt.legend(legend)
plt.savefig(filename, bbox_inches='tight')
def _log_barh(self, series, filename, title=None, legend=None, log_scale=False):
names, values = zip(*series)
names_only = [item[0] if isinstance(item, tuple) else item for item in names]
fig, ax = plt.subplots(figsize=(10, max(5, len(names)/4)))
plt.ylim(-0.5, len(names) - 0.5)
if isinstance(values[0], list):
series = [list(i) for i in zip(*values)]
left = np.zeros(len(series[0]))
for i, serie in enumerate(series):
plt.barh(names_only, serie, left=left)
left += serie
else:
plt.barh(names_only, values)
ax.xaxis.minorticks_on()
plt.grid(axis='x', which='major', linestyle='--', color='gray')
plt.grid(axis='x', which='minor', linestyle=':', color='lightgray')
formatter0 = matplotlib.ticker.EngFormatter(unit='')
ax.xaxis.set_major_formatter(formatter0)
plt.gca().set_axisbelow(True)
plt.gca().xaxis.set_label_position('top')
plt.gca().xaxis.tick_top()
labels = plt.gca().get_yticks()
tick_labels = plt.gca().get_yticklabels()
for i, label in enumerate(labels):
if isinstance(names[i], tuple):
if 'color' in names[i][1]:
tick_labels[i].set_color(names[i][1]['color'])
elif 'fontweight' in names[i][1]:
tick_labels[i].set_fontweight(names[i][1]['fontweight'])
if log_scale: plt.xscale('log')
if title is not None: plt.title(title)
if legend is not None: plt.legend(legend)
plt.savefig(filename, bbox_inches='tight')
...@@ -52,6 +52,9 @@ enum class DataType { ...@@ -52,6 +52,9 @@ enum class DataType {
Any Any
}; };
bool isDataTypeFloatingPoint(const DataType& type);
size_t getDataTypeBitWidth(const DataType& type);
enum class DataFormat { enum class DataFormat {
Default, Default,
NCHW, NCHW,
......
...@@ -154,13 +154,13 @@ public: ...@@ -154,13 +154,13 @@ public:
*/ */
std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches); std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);
inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { inline void addNodeLambda(const std::string& name, std::function<bool(const NodePtr&)> func) {
mLambda[name] = func; mLambda[name] = func;
} }
private: private:
std::shared_ptr<GraphView> mGraph; std::shared_ptr<GraphView> mGraph;
std::map<std::string, bool(*)(const NodePtr&)> mLambda; std::map<std::string, std::function<bool(const NodePtr&)>> mLambda;
/** /**
* QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}') * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_GRAPH_STATICANALYSIS_H_
#define AIDGE_CORE_GRAPH_STATICANALYSIS_H_
#include <memory>
#include "aidge/utils/Registrar.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/MetaOperator.hpp"
namespace Aidge {
/**
* @brief Base class to compute statistics from an Operator.
*
*/
class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
public:
OperatorStats(const Operator& op);
const Operator& getOperator() const noexcept { return mOp; }
/**
* @brief Get the worst case total number of arithmetic operations for the
* operator data flow. This includes base arithmetic operations: +, -, / and *.
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only arithmetic operatons: Conv.
*
* @return size_t Number of arithmetic operations.
*/
virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
/**
* @brief Get the worst case total number of logic operations for the
* operator data flow. This includes operations like logical shift, or, and...
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only logic operatons: BitShift.
*
* @return size_t Number of logic operations.
*/
virtual size_t getNbLogicOps() const { return 0; };
/**
* @brief Get the worst case total number of comparison operations for the
* operator data flow. This includes operations like <, >, =...
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only comparison operatons: MaxPool.
*
* @return size_t Number of comparison operations.
*/
virtual size_t getNbCompOps() const { return 0; };
/**
* @brief Get the worst case total number of non-linear (NL) operations for the
* operator data flow. This includes operations like calls to tanh(), erf(), cos()...
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only NL operatons: Tanh.
* Non-linear operations are necessarily of floating-point type.
*
* @return size_t Number of non-linear (NL) operations.
*/
virtual size_t getNbNLOps() const { return 0; };
/**
* @brief Get the worst case total number of operations for the operator data flow.
* Total number of operations = arithmetic ops + logic ops + comp ops + NL ops.
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
*
* @return size_t Number of operations.
*/
size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbNLOps(); };
/**
* @brief Get the worst case total number of INT arithmetic operations for
* the operator data flow.
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
*
* @return size_t Number of INT arithmetic operations.
*/
virtual size_t getNbArithmIntOps() const;
/**
* @brief Get the worst case total number of FP arithmetic operations for
* the operator data flow.
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
*
* @return size_t Number of FP arithmetic operations.
*/
size_t getNbArithmFpOps() const { return getNbArithmOps() - getNbArithmIntOps(); };
/**
* @brief Get the worst case total number of MAC operations for the operator
* data flow. MAC operations are included in getNbArithmOps(), with 1 MAC
* operation counted as 2 arithmetic operations. MAC can be INT of FP.
* Control flow operations (loop counters, index computation...) and memory
* accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
*
* @return size_t Number of MAC operations.
*/
virtual size_t getNbMACOps() const { return 0; };
virtual ~OperatorStats() = default;
protected:
const Operator &mOp;
};
/**
* @brief Base class to compute statistics from a GraphView
*
*/
class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
public:
StaticAnalysis(std::shared_ptr<GraphView> graph);
const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
/**
* @brief Get the Operator Stats object corresponding to the given node.
*
* @param node Node
* @return std::shared_ptr<OperatorStats> Node's Operator stats
*/
std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const;
/**
* @brief Get the number of parameters associated to a node. This includes
* all Producers directly connected to the node's inputs as well as all
* internal Producers (in case of a meta operator).
*
* Note: this function does not check if parameters are shared between
* several nodes or not. This means that simply adding parameters count from
* several nodes may lead to a higher number of parameters than in reality
* if some of them are shared.
*
* @param node Node
* @return size_t Number of parameters
*/
virtual size_t getNbParams(std::shared_ptr<Node> node) const;
/**
* @brief Get the total parameters memory size, in bits, associated to a node.
* This includes all Producers directly connected to the node's inputs as
* well as all internal Producers (in case of a meta operator).
*
* Note: this function does not check if parameters are shared between
* several nodes or not. This means that simply adding parameters size from
* several nodes may lead to a higher parameter size than in reality
* if some of them are shared.
*
* @param node Node
* @return size_t Total parameters memory, in bits
*/
virtual size_t getParamsSize(std::shared_ptr<Node> node) const;
size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
size_t getNbNLOps() const { return accumulate(&OperatorStats::getNbNLOps); }
size_t getNbOps() const { return accumulate(&OperatorStats::getNbOps); }
size_t getNbArithmIntOps() const { return accumulate(&OperatorStats::getNbArithmIntOps); }
size_t getNbArithmFpOps() const { return accumulate(&OperatorStats::getNbArithmFpOps); }
size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); }
virtual void summary(bool incProducers = false) const;
virtual ~StaticAnalysis() = default;
protected:
const std::shared_ptr<GraphView> mGraph;
size_t accumulate(size_t (OperatorStats::*func)() const) const;
};
////////////////////////////////////////////////////////////////////////////////
class MetaOpStats : public OperatorStats {
public:
MetaOpStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<MetaOpStats> create(const Operator& op) {
return std::make_unique<MetaOpStats>(op);
}
size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); }
size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); }
size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); }
size_t getNbNLOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); }
size_t getNbArithmIntOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); }
size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); }
};
template <class OP>
class ConvStats : public OperatorStats {
public:
ConvStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ConvStats<OP>> create(const Operator& op) {
return std::make_unique<ConvStats<OP>>(op);
}
size_t getNbMACOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t weightsSize = op_.getInput(1)->size();
const std::size_t outputSize
= std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
op_.getOutput(0)->dims().cend(),
1,
std::multiplies<size_t>()); // NCHW...
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
return batchSize * (weightsSize * outputSize);
}
};
// Beware: cannot use Conv_Op<2>::Type as key because static variable initialization order is undefined!
REGISTRAR(OperatorStats, "Conv1D", ConvStats<Conv_Op<1>>::create);
REGISTRAR(OperatorStats, "ConvDepthWise1D", ConvStats<ConvDepthWise_Op<1>>::create);
REGISTRAR(OperatorStats, "Conv2D", ConvStats<Conv_Op<2>>::create);
REGISTRAR(OperatorStats, "ConvDepthWise2D", ConvStats<ConvDepthWise_Op<2>>::create);
template <class OP>
class MaxPoolingStats : public OperatorStats {
public:
MaxPoolingStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<MaxPoolingStats<OP>> create(const Operator& op) {
return std::make_unique<MaxPoolingStats<OP>>(op);
}
size_t getNbCompOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t poolSize
= std::accumulate(op_.kernelDims().cbegin(),
op_.kernelDims().cend(),
1,
std::multiplies<size_t>());
const std::size_t outputSize
= std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
op_.getOutput(0)->dims().cend(),
1,
std::multiplies<size_t>()); // NCHW...
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
return batchSize * ((poolSize - 1) * outputSize);
}
};
REGISTRAR(OperatorStats, "MaxPooling1D", MaxPoolingStats<MaxPooling_Op<1>>::create);
REGISTRAR(OperatorStats, "MaxPooling2D", MaxPoolingStats<MaxPooling_Op<2>>::create);
REGISTRAR(OperatorStats, "MaxPooling3D", MaxPoolingStats<MaxPooling_Op<3>>::create);
template <class OP>
class AvgPoolingStats : public OperatorStats {
public:
AvgPoolingStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<AvgPoolingStats<OP>> create(const Operator& op) {
return std::make_unique<AvgPoolingStats<OP>>(op);
}
size_t getNbArithmOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t poolSize
= std::accumulate(op_.kernelDims().cbegin(),
op_.kernelDims().cend(),
1,
std::multiplies<size_t>());
const std::size_t outputSize
= std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
op_.getOutput(0)->dims().cend(),
1,
std::multiplies<size_t>()); // NCHW...
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
// (poolSize - 1) additions + 1 division for each output
return batchSize * (poolSize * outputSize);
}
};
REGISTRAR(OperatorStats, "AvgPooling1D", AvgPoolingStats<AvgPooling_Op<1>>::create);
REGISTRAR(OperatorStats, "AvgPooling2D", AvgPoolingStats<AvgPooling_Op<2>>::create);
REGISTRAR(OperatorStats, "AvgPooling3D", AvgPoolingStats<AvgPooling_Op<3>>::create);
REGISTRAR(OperatorStats, "AvgPooling4D", AvgPoolingStats<AvgPooling_Op<4>>::create);
class FCStats : public OperatorStats {
public:
FCStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<FCStats> create(const Operator& op) {
return std::make_unique<FCStats>(op);
}
size_t getNbMACOps() const override {
const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t weightsSize = op_.getInput(1)->size();
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
return batchSize * weightsSize;
}
};
REGISTRAR(OperatorStats, "FC", FCStats::create);
class MatMulStats : public OperatorStats {
public:
MatMulStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<MatMulStats> create(const Operator& op) {
return std::make_unique<MatMulStats>(op);
}
size_t getNbMACOps() const override {
const MatMul_Op& op_ = dynamic_cast<const MatMul_Op&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t n = (op_.getInput(0)->dims().size() > 1)
? op_.getInput(0)->dims().end()[-2] : 1;
const size_t k = op_.getInput(0)->dims().back();
const size_t m = (op_.getInput(1)->dims().size() > 1)
? op_.getInput(1)->dims().back() : 1;
const size_t nb = (op_.getInput(0)->dims().size() > 2)
? std::accumulate(op_.getInput(0)->dims().cbegin(),
op_.getInput(0)->dims().cend() - 2,
1,
std::multiplies<size_t>())
: 1;
return nb * n * m * k;
}
};
REGISTRAR(OperatorStats, "MatMul", MatMulStats::create);
class ReLUStats : public OperatorStats {
public:
ReLUStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ReLUStats> create(const Operator& op) {
return std::make_unique<ReLUStats>(op);
}
size_t getNbCompOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "ReLU", ReLUStats::create);
class AbsStats : public OperatorStats {
public:
AbsStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<AbsStats> create(const Operator& op) {
return std::make_unique<AbsStats>(op);
}
size_t getNbCompOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
// This is in the worst case (all values are negative)
size_t getNbArithmOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "Abs", AbsStats::create);
class ReduceMeanStats : public OperatorStats {
public:
ReduceMeanStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ReduceMeanStats> create(const Operator& op) {
return std::make_unique<ReduceMeanStats>(op);
}
size_t getNbArithmOps() const override {
const ReduceMean_Op& op_ = dynamic_cast<const ReduceMean_Op&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t nbIn = op_.getInput(0)->size();
const size_t nbOut = op_.getOutput(0)->size();
const size_t nbReduce = nbIn / nbOut;
// (nbReduce - 1) additions + 1 division for each output
return nbOut * nbReduce;
}
};
REGISTRAR(OperatorStats, "ReduceMean", ReduceMeanStats::create);
class ReduceSumStats : public OperatorStats {
public:
ReduceSumStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ReduceSumStats> create(const Operator& op) {
return std::make_unique<ReduceSumStats>(op);
}
size_t getNbArithmOps() const override {
const ReduceSum_Op& op_ = dynamic_cast<const ReduceSum_Op&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t nbIn = op_.getInput(0)->size();
const size_t nbOut = op_.getOutput(0)->size();
const size_t nbReduce = nbIn / nbOut;
// (nbReduce - 1) additions for each output
return nbOut * (nbReduce - 1);
}
};
REGISTRAR(OperatorStats, "ReduceSum", ReduceSumStats::create);
class SoftmaxStats : public OperatorStats {
public:
SoftmaxStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<SoftmaxStats> create(const Operator& op) {
return std::make_unique<SoftmaxStats>(op);
}
size_t getNbArithmOps() const override {
const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
const size_t nbReduce = op_.getInput(0)->dims()[axis];
const size_t nbOut = op_.getOutput(0)->size();
// nbOut divisions + (nbReduce - 1) additions
return nbOut + (nbReduce - 1);
}
size_t getNbNLOps() const override {
const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
const size_t nbReduce = op_.getInput(0)->dims()[axis];
const size_t nbOut = op_.getOutput(0)->size();
// nbOut exp + nbReduce exp
return nbOut + nbReduce;
}
};
REGISTRAR(OperatorStats, "Softmax", SoftmaxStats::create);
class MemOpStats : public OperatorStats {
public:
MemOpStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<MemOpStats> create(const Operator& op) {
return std::make_unique<MemOpStats>(op);
}
};
REGISTRAR(OperatorStats, "Reshape", MemOpStats::create);
REGISTRAR(OperatorStats, "Transpose", MemOpStats::create);
REGISTRAR(OperatorStats, "Concat", MemOpStats::create);
REGISTRAR(OperatorStats, "Split", MemOpStats::create);
REGISTRAR(OperatorStats, "Slice", MemOpStats::create);
REGISTRAR(OperatorStats, "Squeeze", MemOpStats::create);
REGISTRAR(OperatorStats, "Unsqueeze", MemOpStats::create);
REGISTRAR(OperatorStats, "Gather", MemOpStats::create);
REGISTRAR(OperatorStats, "Identity", MemOpStats::create);
class ElemWiseOpStats : public OperatorStats {
public:
ElemWiseOpStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ElemWiseOpStats> create(const Operator& op) {
return std::make_unique<ElemWiseOpStats>(op);
}
size_t getNbArithmOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "Add", ElemWiseOpStats::create);
REGISTRAR(OperatorStats, "Sub", ElemWiseOpStats::create);
REGISTRAR(OperatorStats, "Mul", ElemWiseOpStats::create);
REGISTRAR(OperatorStats, "Div", ElemWiseOpStats::create);
class ElemWiseLogicOpStats : public OperatorStats {
public:
ElemWiseLogicOpStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ElemWiseLogicOpStats> create(const Operator& op) {
return std::make_unique<ElemWiseLogicOpStats>(op);
}
size_t getNbArithmOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "And", ElemWiseLogicOpStats::create);
class ElemWiseNLOpStats : public OperatorStats {
public:
ElemWiseNLOpStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ElemWiseNLOpStats> create(const Operator& op) {
return std::make_unique<ElemWiseNLOpStats>(op);
}
size_t getNbNLOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "Atan", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Sqrt", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Erf", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Ln", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Sigmoid", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Tanh", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Pow", ElemWiseNLOpStats::create);
}
#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graphRegex/matchFsm/MatchResult.hpp" #include "aidge/graph/Matching.hpp"
namespace Aidge { namespace Aidge {
...@@ -81,9 +81,6 @@ size_t removeIdentity(std::shared_ptr<GraphView> graph); ...@@ -81,9 +81,6 @@ size_t removeIdentity(std::shared_ptr<GraphView> graph);
*/ */
void removeFlatten(std::shared_ptr<Node> flatten); void removeFlatten(std::shared_ptr<Node> flatten);
void removeFlatten(std::shared_ptr<MatchSolution> solution);
/** /**
* @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
* *
...@@ -151,6 +148,15 @@ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); ...@@ -151,6 +148,15 @@ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false);
*/ */
void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims); void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims);
/**
* Fuse each sub-graph matching a query in a Meta Operator.
* @param gm SinglePassGraphMatching containing the graph to manipulate
* @param query Sub-graph matching query
* @param type Type name of the resulting meta operators
* @return size_t Number of replacement
*/
size_t fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type = "");
/** /**
* Fuse each sub-graph matching a query in a Meta Operator. * Fuse each sub-graph matching a query in a Meta Operator.
* @param graph Graph to manipulate * @param graph Graph to manipulate
......
...@@ -166,7 +166,14 @@ public: ...@@ -166,7 +166,14 @@ public:
attributes[elt.first.c_str()] = future_std::any_cast<const DynamicAttributes&>(elt.second).dict(); attributes[elt.first.c_str()] = future_std::any_cast<const DynamicAttributes&>(elt.second).dict();
} }
else { else {
attributes[elt.first.c_str()] = mAnyUtils.at(elt.second.type())->cast(elt.second); // At this point, not every attribute may be known to mAnyUtils
const auto anyUtilsIt = mAnyUtils.find(elt.second.type());
if (anyUtilsIt != mAnyUtils.end()) {
attributes[elt.first.c_str()] = anyUtilsIt->second->cast(elt.second);
}
else {
attributes[elt.first.c_str()] = "???";
}
} }
} }
return attributes; return attributes;
......
...@@ -149,8 +149,8 @@ void init_GraphView(py::module& m) { ...@@ -149,8 +149,8 @@ void init_GraphView(py::module& m) {
// } // }
// }) // })
.def("get_ranked_nodes", &GraphView::getRankedNodes) .def("get_ranked_nodes", &GraphView::getRankedNodes)
.def("get_ranked_nodes_name", &GraphView::getRankedNodesName, py::arg("format"), py::arg("mark_non_unicity") = true)
.def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat")) .def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat"))
; ;
m.def("get_connected_graph_view", &getConnectedGraphView); m.def("get_connected_graph_view", &getConnectedGraphView);
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
********************************************************************************/ ********************************************************************************/
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -31,21 +32,20 @@ void init_SinglePassGraphMatching(py::module& m) { ...@@ -31,21 +32,20 @@ void init_SinglePassGraphMatching(py::module& m) {
py::class_<Aidge::SinglePassGraphMatching>(m, "SinglePassGraphMatching") py::class_<Aidge::SinglePassGraphMatching>(m, "SinglePassGraphMatching")
.def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
.def("match", .def("match",
[](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){ [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){
// Note: Need to convert set to vector has MatchingResult is not hashable and // Note: Need to convert set to vector has MatchingResult is not hashable and
// set<MatchingResult> cannot be binded // set<MatchingResult> cannot be binded
std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint); std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint);
std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end()); std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end());
return vec_res; return vec_res;
}, },
py::arg("query"), py::arg("disjoint") = false, py::arg("query"), py::arg("disjoint") = false,
R"mydelimiter( Matches a query by direct, single-pass parse and match. R"mydelimiter( Matches a query by direct, single-pass parse and match.
:param query: The query string to search. :param query: The query string to search.
:param disjoint: If true, only keep the longest disjoint matches. :param disjoint: If true, only keep the longest disjoint matches.
:return: A set of MatchingResult instances. :return: A set of MatchingResult instances.
)mydelimiter"); )mydelimiter")
.def("add_node_lambda", &SinglePassGraphMatching::addNodeLambda, py::arg("name"), py::arg("func"));
} }
} // namespace Aidge } // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/graph/StaticAnalysis.hpp"
namespace py = pybind11;
namespace Aidge {
/**
* @brief Trampoline class for binding
*
*/
class pyOperatorStats: public OperatorStats {
public:
using OperatorStats::OperatorStats; // Inherit constructors
size_t getNbArithmOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbArithmOps
);
}
size_t getNbLogicOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbLogicOps
);
}
size_t getNbCompOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbCompOps
);
}
size_t getNbNLOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbNLOps
);
}
size_t getNbArithmIntOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbArithmIntOps
);
}
size_t getNbMACOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbMACOps
);
}
};
class pyStaticAnalysis: public StaticAnalysis {
public:
using StaticAnalysis::StaticAnalysis; // Inherit constructors
size_t getNbParams(std::shared_ptr<Node> node) const override {
PYBIND11_OVERRIDE(
size_t,
StaticAnalysis,
getNbParams,
node
);
}
size_t getParamsSize(std::shared_ptr<Node> node) const override {
PYBIND11_OVERRIDE(
size_t,
StaticAnalysis,
getParamsSize,
node
);
}
void summary(bool incProducers) const override {
PYBIND11_OVERRIDE(
void,
StaticAnalysis,
summary,
incProducers
);
}
};
void init_StaticAnalysis(py::module& m){
py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr())
.def(py::init<const Operator&>(), py::arg("op"))
.def("get_operator", &OperatorStats::getOperator)
.def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
.def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
.def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
.def("get_nb_nl_ops", &OperatorStats::getNbNLOps)
.def("get_nb_ops", &OperatorStats::getNbOps)
.def("get_nb_arithm_int_ops", &OperatorStats::getNbArithmIntOps)
.def("get_nb_arithm_fp_ops", &OperatorStats::getNbArithmFpOps)
.def("get_nb_mac_ops", &OperatorStats::getNbMACOps)
;
declare_registrable<OperatorStats>(m, "OperatorStats");
py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr())
.def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
.def("get_graph", &StaticAnalysis::getGraph)
.def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node"))
.def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node"))
.def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
.def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
.def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
.def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps)
.def("get_nb_ops", &StaticAnalysis::getNbOps)
.def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps)
.def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps)
.def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps)
.def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
.def("get_op_stats", &StaticAnalysis::getOpStats, py::arg("node"))
;
}
}
...@@ -63,6 +63,7 @@ void init_Operator(py::module& m){ ...@@ -63,6 +63,7 @@ void init_Operator(py::module& m){
.def_property_readonly("attr", &Operator::attributes) .def_property_readonly("attr", &Operator::attributes)
.def("set_back_edges", &Operator::setBackEdges, py::arg("input_indexes")) .def("set_back_edges", &Operator::setBackEdges, py::arg("input_indexes"))
.def("is_back_edge", &Operator::isBackEdge, py::arg("input_index")) .def("is_back_edge", &Operator::isBackEdge, py::arg("input_index"))
.def("is_atomic", &Operator::isAtomic)
; ;
} }
} }
...@@ -27,6 +27,7 @@ void init_OperatorImpl(py::module&); ...@@ -27,6 +27,7 @@ void init_OperatorImpl(py::module&);
void init_Log(py::module&); void init_Log(py::module&);
void init_Operator(py::module&); void init_Operator(py::module&);
void init_OperatorTensor(py::module&); void init_OperatorTensor(py::module&);
void init_StaticAnalysis(py::module&);
void init_Add(py::module&); void init_Add(py::module&);
void init_And(py::module&); void init_And(py::module&);
...@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) { ...@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) {
init_Log(m); init_Log(m);
init_Operator(m); init_Operator(m);
init_OperatorTensor(m); init_OperatorTensor(m);
init_StaticAnalysis(m);
init_Add(m); init_Add(m);
init_And(m); init_And(m);
......
...@@ -112,7 +112,20 @@ void init_Recipes(py::module &m) ...@@ -112,7 +112,20 @@ void init_Recipes(py::module &m)
:type recursive: bool :type recursive: bool
)mydelimiter"); )mydelimiter");
m.def("fuse_to_metaops", fuseToMetaOps, py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter( m.def("fuse_to_metaops", py::overload_cast<SinglePassGraphMatching&, const std::string&, const std::string&>(fuseToMetaOps), py::arg("gm"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
Fuse each sub-graph matching a query in a Meta Operator.
:param gm: SinglePassGraphMatching containing the graph to manipulate
:type gm: :py:class:`aidge_core.SinglePassGraphMatching`
:param query: Sub-graph matching query
:type query: str
:param type: Type name of the resulting meta operators
:type type: str, optional
:return: Number of sub-graph actually fused in a Meta Operator.
:rtype: int
)mydelimiter");
m.def("fuse_to_metaops", py::overload_cast<std::shared_ptr<GraphView>, const std::string&, const std::string&>(fuseToMetaOps), py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
Fuse each sub-graph matching a query in a Meta Operator. Fuse each sub-graph matching a query in a Meta Operator.
:param graph_view: Graph view on which we want to apply the recipe :param graph_view: Graph view on which we want to apply the recipe
......
...@@ -11,6 +11,50 @@ ...@@ -11,6 +11,50 @@
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
bool Aidge::isDataTypeFloatingPoint(const DataType& type) {
switch (type) {
case DataType::Float64:
case DataType::Float32:
case DataType::Float16:
case DataType::BFloat16: return true;
default: return false;
}
return false;
}
size_t Aidge::getDataTypeBitWidth(const DataType& type) {
switch (type) {
case DataType::Float64: return 64;
case DataType::Float32: return 32;
case DataType::Float16: return 16;
case DataType::BFloat16: return 16;
case DataType::Binary: return 1;
case DataType::Ternary: return 2;
case DataType::Int2: return 2;
case DataType::Int3: return 3;
case DataType::Int4: return 4;
case DataType::Int5: return 5;
case DataType::Int6: return 6;
case DataType::Int7: return 7;
case DataType::Int8: return 8;
case DataType::Int16: return 16;
case DataType::Int32: return 32;
case DataType::Int64: return 64;
case DataType::UInt2: return 2;
case DataType::UInt3: return 3;
case DataType::UInt4: return 4;
case DataType::UInt5: return 5;
case DataType::UInt6: return 6;
case DataType::UInt7: return 7;
case DataType::UInt8: return 8;
case DataType::UInt16: return 16;
case DataType::UInt32: return 32;
case DataType::UInt64: return 64;
default: return 0;
}
return 0;
}
Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) { Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) {
// Permutation array from default format to src format // Permutation array from default format to src format
const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)]; const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)];
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graph/StaticAnalysis.hpp"
Aidge::OperatorStats::OperatorStats(const Operator& op)
: mOp(op)
{
//ctor
}
size_t Aidge::OperatorStats::getNbArithmIntOps() const {
const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
if (opTensor) {
if (!isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) {
return getNbArithmOps();
}
}
return 0;
}
Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
: mGraph(graph)
{
//ctor
}
void Aidge::StaticAnalysis::summary(bool incProducers) const {
fmt::println("--------------------------------------------------------------------------------");
fmt::println(" Layer (type) Output Shape Param #");
fmt::println("================================================================================");
size_t nbParams = 0;
size_t paramsSize = 0; // Size in bits
size_t fwdBwdSize = 0; // Size in bits
const auto namePtrTable = mGraph->getRankedNodesName("{0} ({1}#{3})");
for (const auto node : mGraph->getOrderedNodes()) {
if (node->type() == Producer_Op::Type && !incProducers) {
continue;
}
auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
std::string outputDimsStr = fmt::format("{: >27}", "?");
if (opTensor) {
const auto outputDims = opTensor->getOutput(0)->dims();
outputDimsStr = fmt::format("{: >27}", fmt::format("{}", outputDims));
for (size_t out = 0; out < node->nbOutputs(); ++out) {
const auto output = opTensor->getOutput(out);
if (output && node->type() != Producer_Op::Type) {
fwdBwdSize += output->size()
* getDataTypeBitWidth(output->dataType());
}
}
}
nbParams += getNbParams(node);
paramsSize += getParamsSize(node);
fmt::println("{: >36}{}{: >16}",
namePtrTable.at(node), outputDimsStr, getNbParams(node));
}
size_t inputSize = 0; // Size in bits
for (const auto input : mGraph->getOrderedInputs()) {
if (input.first) {
auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator());
if (opTensor && opTensor->getInput(input.second)) {
inputSize += opTensor->getInput(input.second)->size()
* getDataTypeBitWidth(opTensor->getInput(input.second)->dataType());
}
}
}
fmt::println("================================================================================");
fmt::println("Total params: {}", nbParams);
fmt::println("--------------------------------------------------------------------------------");
fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024);
fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024);
fmt::println("Params size (MB): {}", paramsSize / 8.0 / 1024 / 1024);
fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
fmt::println("--------------------------------------------------------------------------------");
}
size_t Aidge::StaticAnalysis::getNbParams(std::shared_ptr<Node> node) const {
const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
size_t nbParams = 0;
// Look for Producers directly attached to the node's inputs.
size_t i = 0;
for (auto parent : node->inputs()) {
if (parent.first && mGraph->inView(parent.first)) {
if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) {
nbParams += opTensor->getInput(i)->size();
}
}
++i;
}
// Look for internal Producers, in case of meta-op.
if (!node->getOperator()->isAtomic()) {
const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph();
for (const auto internalNode : microGraph->getNodes()) {
if (internalNode->type() == Producer_Op::Type) {
const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator());
nbParams += internalOpTensor->getOutput(0)->size();
}
}
}
return nbParams;
}
size_t Aidge::StaticAnalysis::getParamsSize(std::shared_ptr<Node> node) const {
const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
size_t paramsSize = 0;
// Look for Producers directly attached to the node's inputs.
size_t i = 0;
for (auto parent : node->inputs()) {
if (parent.first && mGraph->inView(parent.first)) {
if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) {
paramsSize += opTensor->getInput(i)->size()
* getDataTypeBitWidth(opTensor->getInput(i)->dataType());
}
}
++i;
}
// Look for internal Producers, in case of meta-op.
if (!node->getOperator()->isAtomic()) {
const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph();
for (const auto internalNode : microGraph->getNodes()) {
if (internalNode->type() == Producer_Op::Type) {
const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator());
paramsSize += internalOpTensor->getOutput(0)->size()
* getDataTypeBitWidth(internalOpTensor->getOutput(0)->dataType());
}
}
}
return paramsSize;
}
std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
return (Registrar<OperatorStats>::exists(node->type()))
? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
: (node->getOperator()->isAtomic())
? std::make_shared<OperatorStats>(*(node->getOperator()))
: std::make_shared<MetaOpStats>(*(node->getOperator()));
}
size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const {
return std::accumulate(
mGraph->getNodes().cbegin(),
mGraph->getNodes().cend(),
std::size_t(0),
[this, func](const size_t& lhs, const std::shared_ptr<Node>& rhs) {
return lhs + (this->getOpStats(rhs).get()->*func)();
});
}
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperator.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { size_t Aidge::fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type) {
const auto metaType = (!type.empty()) ? type : query; const auto metaType = (!type.empty()) ? type : query;
const auto matches = SinglePassGraphMatching(graphView).match(query); const auto matches = gm.match(query);
size_t nbReplaced = 0; size_t nbReplaced = 0;
for (const auto& match : matches) { for (const auto& match : matches) {
...@@ -48,3 +48,8 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str ...@@ -48,3 +48,8 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str
Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size()); Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size());
return nbReplaced; return nbReplaced;
} }
size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) {
SinglePassGraphMatching gm(graphView);
return fuseToMetaOps(gm, query, type);
}
...@@ -50,9 +50,9 @@ TEST_CASE("[core/graph] Matching") { ...@@ -50,9 +50,9 @@ TEST_CASE("[core/graph] Matching") {
ReLU("relu2"), ReLU("relu2"),
PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}), PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
ReLU("relu3"), ReLU("relu3"),
PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
Add("add"), Add("add"),
PaddedConv(8, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
ReLU("relu5"), ReLU("relu5"),
Add("add2") Add("add2")
}); });
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <fmt/chrono.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graph/StaticAnalysis.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Producer.hpp"
using namespace Aidge;
TEST_CASE("[core/graph] StaticAnalysis") {
SECTION("Conv") {
auto g1 = Sequential({
Conv(3, 4, {5, 5}, "conv1"),
ReLU("relu1"),
PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
ReLU("relu2"),
PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
ReLU("relu3"),
PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
Add("add"),
PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
ReLU("relu5"),
Add("add2")
});
g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1);
g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1);
g1->updateInputsOutputs();
g1->forwardDims({{16, 3, 512, 512}});
StaticAnalysis stats(g1);
REQUIRE(stats.getNbParams(g1->getNode("conv1")) == 3 * 4 * 5 * 5 + 4);
REQUIRE(stats.getNbParams(g1->getNode("conv2")) == 4 * 8 * 5 * 5 + 8);
REQUIRE(stats.getNbParams(g1->getNode("conv3")) == 8 * 16 * 3 * 3 + 16);
const auto conv1Stats = stats.getOpStats(g1->getNode("conv1"));
REQUIRE(conv1Stats->getNbMACOps() == 1LL * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmFpOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmIntOps() == 0);
g1->getNode("conv1")->getOperator()->setDataType(DataType::Int8);
REQUIRE(conv1Stats->getNbMACOps() == 1LL * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmFpOps() == 0);
REQUIRE(conv1Stats->getNbArithmIntOps() == 2LL * (16 * 508 * 508) * (5 * 5 * 3 * 4));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment