diff --git a/aidge_core/static_analysis.py b/aidge_core/static_analysis.py index 92922becb023aa51378dd5ac09cffb4fe05c7c0b..cb288d06a2c297ce3106c9dd44e98380086c9fe9 100644 --- a/aidge_core/static_analysis.py +++ b/aidge_core/static_analysis.py @@ -1,6 +1,7 @@ import matplotlib import matplotlib.pyplot as plt from functools import partial +import numpy as np import aidge_core class StaticAnalysisExt(aidge_core.StaticAnalysis): @@ -25,12 +26,18 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis): def log_nb_comp_ops(self, filename, title=None, log_scale=False): self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale) - def log_nb_mac_ops(self, filename, title=None, log_scale=False): - self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale) - def log_nb_flops(self, filename, title=None, log_scale=False): self._log_callback(aidge_core.OperatorStats.get_nb_flops, filename, title, log_scale) + def log_nb_ops(self, filename, title=None, log_scale=False): + self._log_callback([aidge_core.OperatorStats.get_nb_arithm_ops, + aidge_core.OperatorStats.get_nb_logic_ops, + aidge_core.OperatorStats.get_nb_comp_ops, + aidge_core.OperatorStats.get_nb_flops], filename, title, log_scale) + + def log_nb_mac_ops(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale) + def _log_callback(self, callback, filename, title=None, log_scale=False): """ Log a statistic given by an OperatorStats callback member function. @@ -68,9 +75,13 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis): attr = {'color': 'red'} if attr: name = (name, attr) - series.append([name, partial(callback, stats)()]) + if isinstance(callback, list): + series.append([name, [partial(cb, stats)() for cb in callback]]) + if title is None: title = str([cb.__name__ for cb in callback]) + else: + series.append([name, partial(callback, stats)()]) + if title is None: title = callback.__name__ - if title is None: title = str(callback) self._log_bar(series, filename, title, log_scale) def _log_bar(self, series, filename, title=None, log_scale=False): @@ -78,7 +89,14 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis): names_only = [item[0] if isinstance(item, tuple) else item for item in names] fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5)) plt.xlim(-0.5, len(names) - 0.5) - plt.bar(names_only, values) + if isinstance(values[0], list): + series = [list(i) for i in zip(*values)] + bot = np.zeros(len(series[0])) + for i, serie in enumerate(series): + plt.bar(names_only, serie, bottom=bot) + bot += serie + else: + plt.bar(names_only, values) ax.yaxis.minorticks_on() plt.grid(axis='y', which='major', linestyle='--', color='gray') plt.grid(axis='y', which='minor', linestyle=':', color='lightgray') diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp index 535d108570e4b1c225f149688535b96fb5697a79..ed2930389ad0de01cab50674e8f0d27abab5c869 100644 --- a/include/aidge/graph/StaticAnalysis.hpp +++ b/include/aidge/graph/StaticAnalysis.hpp @@ -24,6 +24,8 @@ #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/FC.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/ReLU.hpp" #include "aidge/operator/MetaOperator.hpp" namespace Aidge { @@ -42,8 +44,9 @@ public: virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); }; virtual size_t getNbLogicOps() const { return 0; }; virtual size_t getNbCompOps() const { return 0; }; + virtual size_t getNbFlops() const { return 0; }; + size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbFlops(); }; virtual size_t getNbMACOps() const { return 0; }; - virtual size_t getNbFlops() const; virtual ~OperatorStats() = default; protected: @@ -65,8 +68,8 @@ public: size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } - size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); } size_t getNbFlops() const { return accumulate(&OperatorStats::getNbFlops); } + size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); } virtual void summary(bool incProducers = false) const; virtual ~StaticAnalysis() = default; @@ -93,8 +96,8 @@ public: size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } - size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } size_t getNbFlops() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFlops(); } + size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } }; template <class OP> @@ -108,14 +111,9 @@ public: size_t getNbMACOps() const override { const OP& op_ = dynamic_cast<const OP&>(mOp); - const std::size_t kernelSize = std::accumulate( - op_.kernelDims().cbegin(), - op_.kernelDims().cend(), - std::size_t(1), - std::multiplies<std::size_t>()); - // NCHW - const std::size_t outputSize = op_.getOutput(0)->dims()[2] * op_.getOutput(0)->dims()[3]; - return (kernelSize * outputSize); + const std::size_t weightsSize = op_.getInput(1)->size(); + const std::size_t outputSize = op_.getOutput(0)->dims()[2] * op_.getOutput(0)->dims()[3]; // NCHW + return (weightsSize * outputSize); } }; @@ -133,11 +131,114 @@ public: size_t getNbMACOps() const override { const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp); - return op_.getInput(1)->size(); + const std::size_t weightsSize = op_.getInput(1)->size(); + return weightsSize; } }; REGISTRAR(OperatorStats, "FC", FCStats::create); + +class MatMulStats : public OperatorStats { +public: + MatMulStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<MatMulStats> create(const Operator& op) { + return std::make_unique<MatMulStats>(op); + } + + size_t getNbMACOps() const override { + const MatMul_Op& op_ = dynamic_cast<const MatMul_Op&>(mOp); + const size_t n = (op_.getInput(0)->dims().size() > 1) + ? op_.getInput(0)->dims().end()[-2] : 1; + const size_t k = op_.getInput(0)->dims().back(); + const size_t m = (op_.getInput(1)->dims().size() > 1) + ? op_.getInput(1)->dims().back() : 1; + const size_t nb = (op_.getInput(0)->dims().size() > 2) + ? std::accumulate(op_.getInput(0)->dims().cbegin(), + op_.getInput(0)->dims().cend() - 2, + 1, + std::multiplies<size_t>()) + : 1; + + return nb * n * m * k; + } +}; + +REGISTRAR(OperatorStats, "MatMul", MatMulStats::create); + +class ReLUStats : public OperatorStats { +public: + ReLUStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ReLUStats> create(const Operator& op) { + return std::make_unique<ReLUStats>(op); + } + + size_t getNbCompOps() const override { + const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp); + return op_.getOutput(0)->size(); + } +}; + +REGISTRAR(OperatorStats, "ReLU", ReLUStats::create); + +class MemOpStats : public OperatorStats { +public: + MemOpStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<MemOpStats> create(const Operator& op) { + return std::make_unique<MemOpStats>(op); + } +}; + +REGISTRAR(OperatorStats, "Reshape", MemOpStats::create); +REGISTRAR(OperatorStats, "Transpose", MemOpStats::create); +REGISTRAR(OperatorStats, "Concat", MemOpStats::create); +REGISTRAR(OperatorStats, "Split", MemOpStats::create); +REGISTRAR(OperatorStats, "Slice", MemOpStats::create); +REGISTRAR(OperatorStats, "Squeeze", MemOpStats::create); +REGISTRAR(OperatorStats, "Unsqueeze", MemOpStats::create); +REGISTRAR(OperatorStats, "Gather", MemOpStats::create); +REGISTRAR(OperatorStats, "Identity", MemOpStats::create); + +class ElemWiseOpStats : public OperatorStats { +public: + ElemWiseOpStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ElemWiseOpStats> create(const Operator& op) { + return std::make_unique<ElemWiseOpStats>(op); + } + + size_t getNbArithmOps() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + return op_.getOutput(0)->size(); + } +}; + +REGISTRAR(OperatorStats, "Add", ElemWiseOpStats::create); +REGISTRAR(OperatorStats, "Sub", ElemWiseOpStats::create); +REGISTRAR(OperatorStats, "Mul", ElemWiseOpStats::create); +REGISTRAR(OperatorStats, "Div", ElemWiseOpStats::create); + +class ElemWiseFlopStats : public OperatorStats { +public: + ElemWiseFlopStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ElemWiseFlopStats> create(const Operator& op) { + return std::make_unique<ElemWiseFlopStats>(op); + } + + size_t getNbFlops() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + return op_.getOutput(0)->size(); + } +}; + +REGISTRAR(OperatorStats, "Sqrt", ElemWiseFlopStats::create); +REGISTRAR(OperatorStats, "Erf", ElemWiseFlopStats::create); +REGISTRAR(OperatorStats, "Ln", ElemWiseFlopStats::create); +REGISTRAR(OperatorStats, "Sigmoid", ElemWiseFlopStats::create); +REGISTRAR(OperatorStats, "Tanh", ElemWiseFlopStats::create); } #endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */ diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp index aa12e29c7673e4748d3a895ae469f37aa99ddfc3..6e7cb3bed0cfcb6952752271cf4c5af707cb03f9 100644 --- a/src/graph/StaticAnalysis.cpp +++ b/src/graph/StaticAnalysis.cpp @@ -53,16 +53,6 @@ size_t Aidge::OperatorStats::getParamsSize() const { return paramsSize; } -size_t Aidge::OperatorStats::getNbFlops() const { - const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); - if (opTensor) { - if (isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) { - return getNbArithmOps(); - } - } - return 0; -} - Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph) : mGraph(graph) {