Skip to content
Snippets Groups Projects
Commit 9a95887b authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Axel Farrugia
Browse files

Improved concept

parent 6e37113c
No related branches found
No related tags found
2 merge requests!279v0.4.0,!250[Feat](Exports) Add custom options to exports
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from functools import partial
import aidge_core import aidge_core
class StaticAnalysisExt(aidge_core.StaticAnalysis): class StaticAnalysisExt(aidge_core.StaticAnalysis):
def log_nb_params(self, filename): def log_nb_params(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_params, filename, title, log_scale)
def log_nb_fixed_params(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_fixed_params, filename, title, log_scale)
def log_nb_trainable_params(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_trainable_params, filename, title, log_scale)
def log_params_size(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_params_size, filename, title, log_scale)
def log_nb_arithm_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale)
def log_nb_logic_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale)
def log_nb_comp_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale)
def log_nb_mac_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)
def log_nb_flops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_flops, filename, title, log_scale)
def _log_callback(self, callback, filename, title=None, log_scale=False):
"""
Log a statistic given by an OperatorStats callback member function.
Usage:
stats = StaticAnalysisExt(model)
stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator")
:param func: OperatorStats member function to call.
:param filename: Output graph file name.
:type filename: str
:param title: Title of the graph.
:type title: str
"""
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes() nodes = self.get_graph().get_ordered_nodes()
names = [] series = []
values = []
for node in nodes: for node in nodes:
if node.type() == "Producer": if node.type() == "Producer":
continue continue
if node.type() in aidge_core.get_keys_OperatorStats(): stats = self.get_op_stats(node)
stats = aidge_core.get_key_value_OperatorStats(node.type()) series.append([namePtrTable[node], partial(callback, stats)()])
else:
stats = aidge_core.OperatorStats(node.get_operator()) if title is None: title = str(callback)
names.append(namePtrTable[node]) self._log_bar(series, filename, title, log_scale)
values.append(stats.get_nb_params())
def _log_bar(self, series, filename, title=None, log_scale=False):
names, values = zip(*series)
fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5)) fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5))
plt.xlim(-0.5, len(names) - 0.5) plt.xlim(-0.5, len(names) - 0.5)
plt.bar(names, values) plt.bar(names, values)
...@@ -27,5 +69,6 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis): ...@@ -27,5 +69,6 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
plt.grid(axis='y', which='minor', linestyle=':', color='lightgray') plt.grid(axis='y', which='minor', linestyle=':', color='lightgray')
plt.gca().set_axisbelow(True) plt.gca().set_axisbelow(True)
plt.xticks(rotation='vertical') plt.xticks(rotation='vertical')
plt.title('Number of params per operator') if log_scale: plt.yscale('log')
if title is not None: plt.title(title)
plt.savefig(filename, bbox_inches='tight') plt.savefig(filename, bbox_inches='tight')
...@@ -52,6 +52,7 @@ enum class DataType { ...@@ -52,6 +52,7 @@ enum class DataType {
Any Any
}; };
bool isDataTypeFloatingPoint(const DataType& type);
size_t getDataTypeBitWidth(const DataType& type); size_t getDataTypeBitWidth(const DataType& type);
enum class DataFormat { enum class DataFormat {
......
...@@ -22,19 +22,15 @@ ...@@ -22,19 +22,15 @@
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MetaOperator.hpp"
namespace Aidge { namespace Aidge {
class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { /**
public: * @brief Base class to compute statistics from an Operator
StaticAnalysis(std::shared_ptr<GraphView> graph); *
const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } */
virtual void summary(bool incProducers = false) const;
virtual ~StaticAnalysis() = default;
protected:
const std::shared_ptr<GraphView> mGraph;
};
class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> { class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
public: public:
OperatorStats(const Operator& op); OperatorStats(const Operator& op);
...@@ -43,32 +39,105 @@ public: ...@@ -43,32 +39,105 @@ public:
virtual size_t getNbFixedParams() const { return 0; }; virtual size_t getNbFixedParams() const { return 0; };
virtual size_t getNbTrainableParams() const; virtual size_t getNbTrainableParams() const;
virtual size_t getParamsSize() const; virtual size_t getParamsSize() const;
virtual size_t getNbMemAccess() const { return 0; }; virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
virtual size_t getNbLogicOps() const { return 0; }; virtual size_t getNbLogicOps() const { return 0; };
virtual size_t getNbCompOps() const { return 0; }; virtual size_t getNbCompOps() const { return 0; };
virtual size_t getNbMACOps() const { return 0; }; virtual size_t getNbMACOps() const { return 0; };
virtual size_t getNbFlops() const { return 0; }; virtual size_t getNbFlops() const;
virtual ~OperatorStats() = default; virtual ~OperatorStats() = default;
protected: protected:
const Operator &mOp; const Operator &mOp;
}; };
/**
* @brief Base class to compute statistics from a GraphView
*
*/
class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
public:
StaticAnalysis(std::shared_ptr<GraphView> graph);
const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
size_t getNbParams() const { return accumulate(&OperatorStats::getNbParams); }
size_t getNbFixedParams() const { return accumulate(&OperatorStats::getNbFixedParams); }
size_t getNbTrainableParams() const { return accumulate(&OperatorStats::getNbTrainableParams); }
size_t getParamsSize() const { return accumulate(&OperatorStats::getParamsSize); }
size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); }
size_t getNbFlops() const { return accumulate(&OperatorStats::getNbFlops); }
virtual void summary(bool incProducers = false) const;
virtual ~StaticAnalysis() = default;
protected:
const std::shared_ptr<GraphView> mGraph;
std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const;
size_t accumulate(size_t (OperatorStats::*func)() const) const;
};
////////////////////////////////////////////////////////////////////////////////
class MetaOpStats : public OperatorStats {
public:
MetaOpStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<MetaOpStats> create(const Operator& op) {
return std::make_unique<MetaOpStats>(op);
}
size_t getNbFixedParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFixedParams(); }
size_t getNbTrainableParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbTrainableParams(); }
size_t getParamsSize() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getParamsSize(); }
size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); }
size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); }
size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); }
size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); }
size_t getNbFlops() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFlops(); }
};
template <class OP>
class ConvStats : public OperatorStats { class ConvStats : public OperatorStats {
public: public:
ConvStats(const Operator& op) : OperatorStats(op) {} ConvStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<ConvStats> create(const Operator& op) { static std::unique_ptr<ConvStats<OP>> create(const Operator& op) {
return std::make_unique<ConvStats>(op); return std::make_unique<ConvStats<OP>>(op);
}
size_t getNbMACOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
const std::size_t kernelSize = std::accumulate(
op_.kernelDims().cbegin(),
op_.kernelDims().cend(),
std::size_t(1),
std::multiplies<std::size_t>());
// NCHW
const std::size_t outputSize = op_.getOutput(0)->dims()[2] * op_.getOutput(0)->dims()[3];
return (kernelSize * outputSize);
}
};
// Beware: cannot use Conv_Op<2>::Type as key because static variable initialization order is undefined!
REGISTRAR(OperatorStats, "Conv", ConvStats<Conv_Op<2>>::create);
REGISTRAR(OperatorStats, "ConvDepthWise", ConvStats<ConvDepthWise_Op<2>>::create);
class FCStats : public OperatorStats {
public:
FCStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<FCStats> create(const Operator& op) {
return std::make_unique<FCStats>(op);
} }
size_t getNbMACOps() const { size_t getNbMACOps() const override {
return 0; const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
return op_.getInput(1)->size();
} }
}; };
REGISTRAR(OperatorStats, Conv_Op<2>::Type, ConvStats::create); REGISTRAR(OperatorStats, "FC", FCStats::create);
} }
#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */ #endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
...@@ -49,14 +49,6 @@ public: ...@@ -49,14 +49,6 @@ public:
); );
} }
size_t getNbMemAccess() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbMemAccess
);
}
size_t getNbArithmOps() const override { size_t getNbArithmOps() const override {
PYBIND11_OVERRIDE( PYBIND11_OVERRIDE(
size_t, size_t,
...@@ -112,15 +104,20 @@ public: ...@@ -112,15 +104,20 @@ public:
} }
}; };
// See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions
class StaticAnalysis_Publicist : public StaticAnalysis {
public:
using StaticAnalysis::getOpStats;
};
void init_StaticAnalysis(py::module& m){ void init_StaticAnalysis(py::module& m){
py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::dynamic_attr()) py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr())
.def(py::init<const Operator&>(), py::arg("op")) .def(py::init<const Operator&>(), py::arg("op"))
.def("get_operator", &OperatorStats::getOperator) .def("get_operator", &OperatorStats::getOperator)
.def("get_nb_params", &OperatorStats::getNbParams) .def("get_nb_params", &OperatorStats::getNbParams)
.def("get_nb_fixed_params", &OperatorStats::getNbFixedParams) .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
.def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams) .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
.def("get_params_size", &OperatorStats::getParamsSize) .def("get_params_size", &OperatorStats::getParamsSize)
.def("get_nb_mem_access", &OperatorStats::getNbMemAccess)
.def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps) .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
.def("get_nb_logic_ops", &OperatorStats::getNbLogicOps) .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
.def("get_nb_comp_ops", &OperatorStats::getNbCompOps) .def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
...@@ -129,10 +126,20 @@ void init_StaticAnalysis(py::module& m){ ...@@ -129,10 +126,20 @@ void init_StaticAnalysis(py::module& m){
; ;
declare_registrable<OperatorStats>(m, "OperatorStats"); declare_registrable<OperatorStats>(m, "OperatorStats");
py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::dynamic_attr()) py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr())
.def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
.def("get_graph", &StaticAnalysis::getGraph) .def("get_graph", &StaticAnalysis::getGraph)
.def("get_nb_params", &StaticAnalysis::getNbParams)
.def("get_nb_fixed_params", &StaticAnalysis::getNbFixedParams)
.def("get_nb_trainable_params", &StaticAnalysis::getNbTrainableParams)
.def("get_params_size", &StaticAnalysis::getParamsSize)
.def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
.def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
.def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
.def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps)
.def("get_nb_flops", &StaticAnalysis::getNbFlops)
.def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
.def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node"))
; ;
} }
} }
...@@ -11,6 +11,17 @@ ...@@ -11,6 +11,17 @@
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
bool Aidge::isDataTypeFloatingPoint(const DataType& type) {
switch (type) {
case DataType::Float64:
case DataType::Float32:
case DataType::Float16:
case DataType::BFloat16: return true;
default: return false;
}
return false;
}
size_t Aidge::getDataTypeBitWidth(const DataType& type) { size_t Aidge::getDataTypeBitWidth(const DataType& type) {
switch (type) { switch (type) {
case DataType::Float64: return 64; case DataType::Float64: return 64;
......
...@@ -53,6 +53,16 @@ size_t Aidge::OperatorStats::getParamsSize() const { ...@@ -53,6 +53,16 @@ size_t Aidge::OperatorStats::getParamsSize() const {
return paramsSize; return paramsSize;
} }
size_t Aidge::OperatorStats::getNbFlops() const {
const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
if (opTensor) {
if (isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) {
return getNbArithmOps();
}
}
return 0;
}
Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph) Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
: mGraph(graph) : mGraph(graph)
{ {
...@@ -89,9 +99,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { ...@@ -89,9 +99,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
} }
} }
const auto stats = (Registrar<OperatorStats>::exists(node->type())) const auto stats = getOpStats(node);
? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
: std::make_shared<OperatorStats>(*(node->getOperator()));
nbTrainableParams += stats->getNbTrainableParams(); nbTrainableParams += stats->getNbTrainableParams();
nbFixedParams += stats->getNbFixedParams(); nbFixedParams += stats->getNbFixedParams();
paramsSize += stats->getParamsSize(); paramsSize += stats->getParamsSize();
...@@ -120,3 +128,21 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { ...@@ -120,3 +128,21 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024); fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
fmt::println("--------------------------------------------------------------------------------"); fmt::println("--------------------------------------------------------------------------------");
} }
std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
return (Registrar<OperatorStats>::exists(node->type()))
? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
: (node->getOperator()->isAtomic())
? std::make_shared<OperatorStats>(*(node->getOperator()))
: std::make_shared<MetaOpStats>(*(node->getOperator()));
}
size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const {
return std::accumulate(
mGraph->getNodes().cbegin(),
mGraph->getNodes().cend(),
std::size_t(0),
[this, func](const size_t& lhs, const std::shared_ptr<Node>& rhs) {
return lhs + (this->getOpStats(rhs).get()->*func)();
});
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment