Commit 6f272b85 authored by Olivier BICHLER, committed by Axel Farrugia

Improved type of ops and supported ops

parent e07693af
2 merge requests: !279 v0.4.0, !250 [Feat](Exports) Add custom options to exports
@@ -26,17 +26,27 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
def log_nb_comp_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale)
def log_nb_flops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_flops, filename, title, log_scale)
def log_nb_nl_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_nl_ops, filename, title, log_scale)
def log_nb_mac_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)
def log_nb_ops(self, filename, title=None, log_scale=False):
self._log_callback([aidge_core.OperatorStats.get_nb_arithm_ops,
self._log_callback(aidge_core.OperatorStats.get_nb_ops, filename, title, log_scale)
def log_nb_arithm_int_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_arithm_int_ops, filename, title, log_scale)
def log_nb_arithm_fp_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_arithm_fp_ops, filename, title, log_scale)
def log_nb_ops_by_type(self, filename, title=None, log_scale=False):
self._log_callback([aidge_core.OperatorStats.get_nb_arithm_int_ops,
aidge_core.OperatorStats.get_nb_arithm_fp_ops,
aidge_core.OperatorStats.get_nb_logic_ops,
aidge_core.OperatorStats.get_nb_comp_ops,
aidge_core.OperatorStats.get_nb_flops], filename, title, log_scale)
def log_nb_mac_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)
aidge_core.OperatorStats.get_nb_nl_ops], filename, title, log_scale)
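
A minimal usage sketch of the new logging helpers (hypothetical: `model` is assumed to be an already-built aidge_core.GraphView with dims forwarded, and the PNG filenames are illustrative):

    # Hedged usage sketch: StaticAnalysisExt is the class defined above.
    stats = StaticAnalysisExt(model)
    stats.log_nb_ops("nb_ops.png", log_scale=True)    # total ops per node
    stats.log_nb_mac_ops("nb_mac_ops.png")            # MACs per node
    stats.log_nb_ops_by_type("nb_ops_by_type.png")    # stacked per-category bars
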
def _log_callback(self, callback, filename, title=None, log_scale=False):
"""
@@ -56,6 +66,7 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
series = []
legend = None
for node in nodes:
if node.type() == "Producer":
@@ -77,14 +88,16 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
name = (name, attr)
if isinstance(callback, list):
series.append([name, [partial(cb, stats)() for cb in callback]])
if title is None: title = str([cb.__name__ for cb in callback])
legend = [cb.__name__ for cb in callback]
if title is None: title = str(legend)
else:
series.append([name, partial(callback, stats)()])
if title is None: title = callback.__name__
self._log_bar(series, filename, title, log_scale)
if title is None: title = str(callback)
self._log_bar(series, filename, title, legend, log_scale)
def _log_bar(self, series, filename, title=None, log_scale=False):
def _log_bar(self, series, filename, title=None, legend=None, log_scale=False):
names, values = zip(*series)
names_only = [item[0] if isinstance(item, tuple) else item for item in names]
fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5))
@@ -116,4 +129,41 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
plt.xticks(rotation='vertical')
if log_scale: plt.yscale('log')
if title is not None: plt.title(title)
if legend is not None: plt.legend(legend)
plt.savefig(filename, bbox_inches='tight')
def _log_barh(self, series, filename, title=None, legend=None, log_scale=False):
names, values = zip(*series)
names_only = [item[0] if isinstance(item, tuple) else item for item in names]
fig, ax = plt.subplots(figsize=(10, max(5, len(names)/4)))
plt.ylim(-0.5, len(names) - 0.5)
if isinstance(values[0], list):
series = [list(i) for i in zip(*values)]
left = np.zeros(len(series[0]))
for i, serie in enumerate(series):
plt.barh(names_only, serie, left=left)
left += serie
else:
plt.barh(names_only, values)
ax.xaxis.minorticks_on()
plt.grid(axis='x', which='major', linestyle='--', color='gray')
plt.grid(axis='x', which='minor', linestyle=':', color='lightgray')
formatter0 = matplotlib.ticker.EngFormatter(unit='')
ax.xaxis.set_major_formatter(formatter0)
plt.gca().set_axisbelow(True)
plt.gca().xaxis.set_label_position('top')
plt.gca().xaxis.tick_top()
labels = plt.gca().get_yticks()
tick_labels = plt.gca().get_yticklabels()
for i, label in enumerate(labels):
if isinstance(names[i], tuple):
if 'color' in names[i][1]:
tick_labels[i].set_color(names[i][1]['color'])
elif 'fontweight' in names[i][1]:
tick_labels[i].set_fontweight(names[i][1]['fontweight'])
if log_scale: plt.xscale('log')
if title is not None: plt.title(title)
if legend is not None: plt.legend(legend)
plt.savefig(filename, bbox_inches='tight')
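
Both _log_bar and _log_barh use the same stacking bookkeeping: per-node rows of per-category counts are transposed into one series per category, and each series is drawn offset by the running total of the previous ones. A standalone sketch of that transform (made-up values):

    import numpy as np

    values = [[4, 0, 2], [10, 1, 0]]               # one [arithm, comp, nl] row per node
    series = [list(col) for col in zip(*values)]   # transpose: one row per category
    offset = np.zeros(len(series[0]))
    for serie in series:
        # each category becomes one bar segment, stacked on the previous ones
        print("segment", serie, "starts at", offset)
        offset += serie
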
@@ -26,6 +26,9 @@
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/MetaOperator.hpp"
namespace Aidge {
@@ -41,11 +44,73 @@ public:
virtual size_t getNbFixedParams() const { return 0; };
virtual size_t getNbTrainableParams() const;
virtual size_t getParamsSize() const;
/**
* @brief Get the total number of arithmetic operations for the operator.
* This includes base arithmetic operations: +, -, / and *.
* Example of Operator with only arithmetic operations: Conv.
*
* @return size_t Number of arithmetic operations.
*/
virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
/**
* @brief Get the total number of logic operations for the operator.
* This includes operations like logical shift, OR, AND...
* Example of Operator with only logic operations: BitShift.
*
* @return size_t Number of logic operations.
*/
virtual size_t getNbLogicOps() const { return 0; };
/**
* @brief Get the total number of comparison operations for the operator.
* This includes operations like <, >, =...
* Example of Operator with only comparison operations: MaxPool.
*
* @return size_t Number of comparison operations.
*/
virtual size_t getNbCompOps() const { return 0; };
virtual size_t getNbFlops() const { return 0; };
size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbFlops(); };
/**
* @brief Get the total number of non-linear (NL) operations for the operator.
* This includes operations like calls to tanh(), erf(), cos()...
* Example of Operator with only NL operations: Tanh.
* Non-linear operations are necessarily of floating-point type.
*
* @return size_t Number of non-linear (NL) operations.
*/
virtual size_t getNbNLOps() const { return 0; };
/**
* @brief Get the total number of operations for the operator.
* Total number of operations = arithmetic ops + logic ops + comp ops + NL ops.
*
* @return size_t Number of operations.
*/
size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbNLOps(); };
/**
* @brief Get the total number of INT arithmetic operations for the operator.
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
*
* @return size_t Number of INT arithmetic operations.
*/
virtual size_t getNbArithmIntOps() const;
/**
* @brief Get the total number of FP arithmetic operations for the operator.
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
*
* @return size_t Number of FP arithmetic operations.
*/
size_t getNbArithmFpOps() const { return getNbArithmOps() - getNbArithmIntOps(); };
/**
* @brief Get the total number of MAC operations for the operator.
*
* @return size_t Number of MAC operations.
*/
virtual size_t getNbMACOps() const { return 0; };
virtual ~OperatorStats() = default;
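
Taken together, the accessors above are meant to satisfy two partition identities. A sketch in terms of the Python bindings further below (hypothetical `stats`, an aidge_core.OperatorStats instance for some node):

    total = (stats.get_nb_arithm_ops() + stats.get_nb_logic_ops()
             + stats.get_nb_comp_ops() + stats.get_nb_nl_ops())
    assert stats.get_nb_ops() == total                  # getNbOps() definition
    assert (stats.get_nb_arithm_ops() ==
            stats.get_nb_arithm_int_ops() + stats.get_nb_arithm_fp_ops())
    # Default base-class behaviour: one MAC = 1 multiplication + 1 addition,
    # hence getNbArithmOps() == 2 * getNbMACOps() unless overridden.
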
@@ -68,7 +133,10 @@ public:
size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
size_t getNbFlops() const { return accumulate(&OperatorStats::getNbFlops); }
size_t getNbNLOps() const { return accumulate(&OperatorStats::getNbNLOps); }
size_t getNbOps() const { return accumulate(&OperatorStats::getNbOps); }
size_t getNbArithmIntOps() const { return accumulate(&OperatorStats::getNbArithmIntOps); }
size_t getNbArithmFpOps() const { return accumulate(&OperatorStats::getNbArithmFpOps); }
size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); }
virtual void summary(bool incProducers = false) const;
virtual ~StaticAnalysis() = default;
@@ -96,7 +164,8 @@ public:
size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); }
size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); }
size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); }
size_t getNbFlops() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFlops(); }
size_t getNbNLOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); }
size_t getNbArithmIntOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); }
size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); }
};
@@ -112,8 +181,13 @@ public:
size_t getNbMACOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
const std::size_t weightsSize = op_.getInput(1)->size();
const std::size_t outputSize = op_.getOutput(0)->dims()[2] * op_.getOutput(0)->dims()[3]; // NCHW
return (weightsSize * outputSize);
const std::size_t outputSize
= std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
op_.getOutput(0)->dims().cend(),
1,
std::multiplies<size_t>()); // NCHW...
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
return batchSize * (weightsSize * outputSize);
}
};
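
A worked instance of the updated, batch-aware convolution count, with assumed shapes (not from the source):

    # Assumed Conv: input NCHW (8, 3, 32, 32), 16 filters of 3x3x3, output (8, 16, 30, 30).
    weights_size = 16 * 3 * 3 * 3        # 432, i.e. getInput(1)->size()
    output_size  = 30 * 30               # product of output dims[2:]
    batch_size   = 8                     # input dims[0]
    macs = batch_size * weights_size * output_size   # 3,110,400 MACs per forward pass
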
@@ -132,7 +206,8 @@ public:
size_t getNbMACOps() const override {
const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
const std::size_t weightsSize = op_.getInput(1)->size();
return weightsSize;
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
return batchSize * weightsSize;
}
};
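
The same batch correction for FC, with assumed shapes:

    # Assumed FC: input (8, 128), weights (64, 128) -> output (8, 64).
    weights_size = 64 * 128              # 8192, i.e. getInput(1)->size()
    batch_size   = 8                     # input dims[0]
    macs = batch_size * weights_size     # 65,536 MACs, i.e. 131,072 arithmetic ops
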
@@ -182,6 +257,75 @@ public:
REGISTRAR(OperatorStats, "ReLU", ReLUStats::create);
class ReduceMeanStats : public OperatorStats {
public:
ReduceMeanStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ReduceMeanStats> create(const Operator& op) {
return std::make_unique<ReduceMeanStats>(op);
}
size_t getNbArithmOps() const override {
const ReduceMean_Op& op_ = dynamic_cast<const ReduceMean_Op&>(mOp);
const size_t nbIn = op_.getInput(0)->size();
const size_t nbOut = op_.getOutput(0)->size();
const size_t nbReduce = nbIn / nbOut;
// (nbReduce - 1) additions + 1 division for each output
return nbOut * nbReduce;
}
};
REGISTRAR(OperatorStats, "ReduceMean", ReduceMeanStats::create);
class ReduceSumStats : public OperatorStats {
public:
ReduceSumStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ReduceSumStats> create(const Operator& op) {
return std::make_unique<ReduceSumStats>(op);
}
size_t getNbArithmOps() const override {
const ReduceSum_Op& op_ = dynamic_cast<const ReduceSum_Op&>(mOp);
const size_t nbIn = op_.getInput(0)->size();
const size_t nbOut = op_.getOutput(0)->size();
const size_t nbReduce = nbIn / nbOut;
// (nbReduce - 1) additions for each output
return nbOut * (nbReduce - 1);
}
};
REGISTRAR(OperatorStats, "ReduceSum", ReduceSumStats::create);
class SoftmaxStats : public OperatorStats {
public:
SoftmaxStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<SoftmaxStats> create(const Operator& op) {
return std::make_unique<SoftmaxStats>(op);
}
size_t getNbArithmOps() const override {
const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
const size_t nbReduce = op_.getInput(0)->dims()[axis];
const size_t nbOut = op_.getOutput(0)->size();
// nbOut divisions + (nbReduce - 1) additions
return nbOut + (nbReduce - 1);
}
size_t getNbNLOps() const override {
const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
const size_t nbReduce = op_.getInput(0)->dims()[axis];
const size_t nbOut = op_.getOutput(0)->size();
// nbOut exp + nbReduce exp
return nbOut + nbReduce;
}
};
REGISTRAR(OperatorStats, "Softmax", SoftmaxStats::create);
class MemOpStats : public OperatorStats {
public:
MemOpStats(const Operator& op) : OperatorStats(op) {}
@@ -220,25 +364,26 @@ REGISTRAR(OperatorStats, "Sub", ElemWiseOpStats::create);
REGISTRAR(OperatorStats, "Mul", ElemWiseOpStats::create);
REGISTRAR(OperatorStats, "Div", ElemWiseOpStats::create);
class ElemWiseFlopStats : public OperatorStats {
class ElemWiseNLOpStats : public OperatorStats {
public:
ElemWiseFlopStats(const Operator& op) : OperatorStats(op) {}
ElemWiseNLOpStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ElemWiseFlopStats> create(const Operator& op) {
return std::make_unique<ElemWiseFlopStats>(op);
static std::unique_ptr<ElemWiseNLOpStats> create(const Operator& op) {
return std::make_unique<ElemWiseNLOpStats>(op);
}
size_t getNbFlops() const override {
size_t getNbNLOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "Sqrt", ElemWiseFlopStats::create);
REGISTRAR(OperatorStats, "Erf", ElemWiseFlopStats::create);
REGISTRAR(OperatorStats, "Ln", ElemWiseFlopStats::create);
REGISTRAR(OperatorStats, "Sigmoid", ElemWiseFlopStats::create);
REGISTRAR(OperatorStats, "Tanh", ElemWiseFlopStats::create);
REGISTRAR(OperatorStats, "Sqrt", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Erf", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Ln", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Sigmoid", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Tanh", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Pow", ElemWiseNLOpStats::create);
}
#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
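
For the element-wise classes above, the count reduces to the output size; e.g. for an assumed Tanh on an (8, 64) tensor:

    # ElemWiseNLOpStats: one non-linear call per produced element (assumed shape).
    nl_ops = 8 * 64                      # getNbNLOps() == getOutput(0)->size() == 512
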
@@ -73,19 +73,27 @@ public:
);
}
size_t getNbMACOps() const override {
size_t getNbNLOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbMACOps
getNbNLOps
);
}
size_t getNbArithmIntOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbArithmIntOps
);
}
size_t getNbFlops() const override {
size_t getNbMACOps() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbFlops
getNbMACOps
);
}
};
@@ -121,8 +129,11 @@ void init_StaticAnalysis(py::module& m){
.def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
.def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
.def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
.def("get_nb_nl_ops", &OperatorStats::getNbNLOps)
.def("get_nb_ops", &OperatorStats::getNbOps)
.def("get_nb_arithm_int_ops", &OperatorStats::getNbArithmIntOps)
.def("get_nb_arithm_fp_ops", &OperatorStats::getNbArithmFpOps)
.def("get_nb_mac_ops", &OperatorStats::getNbMACOps)
.def("get_nb_flops", &OperatorStats::getNbFlops)
;
declare_registrable<OperatorStats>(m, "OperatorStats");
@@ -136,8 +147,11 @@ void init_StaticAnalysis(py::module& m){
.def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
.def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
.def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
.def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps)
.def("get_nb_ops", &StaticAnalysis::getNbOps)
.def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps)
.def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps)
.def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps)
.def("get_nb_flops", &StaticAnalysis::getNbFlops)
.def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
.def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node"))
;
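
A hedged sketch of querying the new bindings from Python (assumes `model` is a GraphView with dims forwarded; method names match the .def lines above):

    import aidge_core

    analysis = aidge_core.StaticAnalysis(model)   # `model`: assumed GraphView
    print("total ops  :", analysis.get_nb_ops())
    print("int arithm :", analysis.get_nb_arithm_int_ops())
    print("fp arithm  :", analysis.get_nb_arithm_fp_ops())
    print("non-linear :", analysis.get_nb_nl_ops())
    analysis.summary()
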
@@ -53,6 +53,16 @@ size_t Aidge::OperatorStats::getParamsSize() const {
return paramsSize;
}
size_t Aidge::OperatorStats::getNbArithmIntOps() const {
const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
if (opTensor) {
if (!isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) {
return getNbArithmOps();
}
}
return 0;
}
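
The split is all-or-nothing per operator, keyed on the output tensor's data type. A sketch of the rule (hypothetical helper, not the C++ API):

    def arithm_int_ops(nb_arithm_ops: int, output_is_floating_point: bool) -> int:
        # Integer-typed operators: every arithmetic op counts as an integer op.
        # Floating-point operators: the whole budget falls into the FP bucket.
        return 0 if output_is_floating_point else nb_arithm_ops

    assert arithm_int_ops(864, output_is_floating_point=False) == 864  # e.g. int8 Conv
    assert arithm_int_ops(864, output_is_floating_point=True) == 0     # e.g. float32 Conv
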
Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
: mGraph(graph)
{