Commit aa939b56 authored by Olivier BICHLER

Improved concept

parent eaf748a3
3 merge requests: !279 v0.4.0, !253 v0.4.0, !219 Initial version of hybrid C++/Python static analysis
Pipeline #56713 passed
import matplotlib.pyplot as plt
from functools import partial
import aidge_core

class StaticAnalysisExt(aidge_core.StaticAnalysis):
    def log_nb_params(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_params, filename, title, log_scale)

    def log_nb_fixed_params(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_fixed_params, filename, title, log_scale)

    def log_nb_trainable_params(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_trainable_params, filename, title, log_scale)

    def log_params_size(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_params_size, filename, title, log_scale)

    def log_nb_arithm_ops(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale)

    def log_nb_logic_ops(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale)

    def log_nb_comp_ops(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale)

    def log_nb_mac_ops(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)

    def log_nb_flops(self, filename, title=None, log_scale=False):
        self._log_callback(aidge_core.OperatorStats.get_nb_flops, filename, title, log_scale)

    def _log_callback(self, callback, filename, title=None, log_scale=False):
        """
        Log a statistic given by an OperatorStats callback member function.
        Usage:

            stats = StaticAnalysisExt(model)
            stats._log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator")

        :param callback: OperatorStats member function to call.
        :param filename: Output graph file name.
        :type filename: str
        :param title: Title of the graph. Defaults to the callback name.
        :type title: str
        :param log_scale: If True, use a log scale for the y-axis.
        :type log_scale: bool
        """
        namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})")
        nodes = self.get_graph().get_ordered_nodes()
        series = []
        for node in nodes:
            if node.type() == "Producer":
                continue
            stats = self.get_op_stats(node)
            series.append([namePtrTable[node], partial(callback, stats)()])
        if title is None: title = str(callback)
        self._log_bar(series, filename, title, log_scale)

    def _log_bar(self, series, filename, title=None, log_scale=False):
        names, values = zip(*series)
        fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5))
        plt.xlim(-0.5, len(names) - 0.5)
        plt.bar(names, values)
@@ -27,5 +69,6 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
        plt.grid(axis='y', which='minor', linestyle=':', color='lightgray')
        plt.gca().set_axisbelow(True)
        plt.xticks(rotation='vertical')
        if log_scale: plt.yscale('log')
        if title is not None: plt.title(title)
        plt.savefig(filename, bbox_inches='tight')
@@ -119,30 +119,17 @@ private:
template <typename T>
const std::string TensorImpl_cpu<T>::Backend = "cpu";

REGISTRAR(Tensor, {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
REGISTRAR(Tensor, {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create);
REGISTRAR(Tensor, {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
REGISTRAR(Tensor, {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create);
} // namespace Aidge

#endif /* AIDGE_CPU_DATA_TENSORIMPL_H_ */
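The effect of these one-line registrations is visible from Python: selecting the "cpu" backend on a tensor with a registered data type dispatches to the matching TensorImpl_cpu specialization. A hedged sketch (the NumPy-array constructor and method spellings are assumptions about this version of the bindings):

    import numpy as np
    import aidge_core
    # Assumption: the pybind Tensor constructor accepts a NumPy array and infers Float32.
    t = aidge_core.Tensor(np.ones((2, 3), dtype=np.float32))
    t.set_backend("cpu")  # looked up via Registrar<Tensor> with key {"cpu", DataType::Float32}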
@@ -52,6 +52,7 @@ enum class DataType {
    Any
};

bool isDataTypeFloatingPoint(const DataType& type);
size_t getDataTypeBitWidth(const DataType& type);

enum class DataFormat {
...
@@ -22,19 +22,15 @@
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MetaOperator.hpp"

namespace Aidge {
/**
 * @brief Base class to compute statistics from an Operator.
 */
class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
public:
    OperatorStats(const Operator& op);
@@ -43,32 +39,105 @@ public:
    virtual size_t getNbFixedParams() const { return 0; };
    virtual size_t getNbTrainableParams() const;
    virtual size_t getParamsSize() const;
    virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
    virtual size_t getNbLogicOps() const { return 0; };
    virtual size_t getNbCompOps() const { return 0; };
    virtual size_t getNbMACOps() const { return 0; };
    virtual size_t getNbFlops() const;
    virtual ~OperatorStats() = default;

protected:
    const Operator &mOp;
};

/**
 * @brief Base class to compute statistics from a GraphView.
 */
class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
public:
    StaticAnalysis(std::shared_ptr<GraphView> graph);
    const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }

    size_t getNbParams() const { return accumulate(&OperatorStats::getNbParams); }
    size_t getNbFixedParams() const { return accumulate(&OperatorStats::getNbFixedParams); }
    size_t getNbTrainableParams() const { return accumulate(&OperatorStats::getNbTrainableParams); }
    size_t getParamsSize() const { return accumulate(&OperatorStats::getParamsSize); }
    size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
    size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
    size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
    size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); }
    size_t getNbFlops() const { return accumulate(&OperatorStats::getNbFlops); }

    virtual void summary(bool incProducers = false) const;
    virtual ~StaticAnalysis() = default;

protected:
    const std::shared_ptr<GraphView> mGraph;

    std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const;
    size_t accumulate(size_t (OperatorStats::*func)() const) const;
};

////////////////////////////////////////////////////////////////////////////////

class MetaOpStats : public OperatorStats {
public:
    MetaOpStats(const Operator& op) : OperatorStats(op) {}

    static std::unique_ptr<MetaOpStats> create(const Operator& op) {
        return std::make_unique<MetaOpStats>(op);
    }

    size_t getNbFixedParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFixedParams(); }
    size_t getNbTrainableParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbTrainableParams(); }
    size_t getParamsSize() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getParamsSize(); }
    size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); }
    size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); }
    size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); }
    size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); }
    size_t getNbFlops() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFlops(); }
};

template <class OP>
class ConvStats : public OperatorStats {
public:
    ConvStats(const Operator& op) : OperatorStats(op) {}

    static std::unique_ptr<ConvStats<OP>> create(const Operator& op) {
        return std::make_unique<ConvStats<OP>>(op);
    }

    size_t getNbMACOps() const override {
        const OP& op_ = dynamic_cast<const OP&>(mOp);
        const std::size_t kernelSize = std::accumulate(
            op_.kernelDims().cbegin(),
            op_.kernelDims().cend(),
            std::size_t(1),
            std::multiplies<std::size_t>());
        // NCHW
        const std::size_t outputSize = op_.getOutput(0)->dims()[2] * op_.getOutput(0)->dims()[3];
        return (kernelSize * outputSize);
    }
};

// Beware: cannot use Conv_Op<2>::Type as key because static variable initialization order is undefined!
REGISTRAR(OperatorStats, "Conv", ConvStats<Conv_Op<2>>::create);
REGISTRAR(OperatorStats, "ConvDepthWise", ConvStats<ConvDepthWise_Op<2>>::create);

class FCStats : public OperatorStats {
public:
    FCStats(const Operator& op) : OperatorStats(op) {}

    static std::unique_ptr<FCStats> create(const Operator& op) {
        return std::make_unique<FCStats>(op);
    }

    size_t getNbMACOps() const override {
        const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
        return op_.getInput(1)->size();
    }
};

REGISTRAR(OperatorStats, "FC", FCStats::create);
}

#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
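Because OperatorStats is a Registrable exposed through declare_registrable in the Python bindings below, node types without a C++ specialization can in principle be covered from Python as well. A hedged sketch, assuming declare_registrable generates a register_OperatorStats module function (following the get_keys_OperatorStats naming used elsewhere in aidge_core); the ReLU cost model is illustrative only:

    import aidge_core

    class ReLUStats(aidge_core.OperatorStats):
        def __init__(self, op):
            aidge_core.OperatorStats.__init__(self, op)

        def get_nb_comp_ops(self):
            # Illustrative assumption: one comparison per output element.
            return self.get_operator().get_output(0).size()

    # Assumed registration entry point; name not confirmed by this diff.
    aidge_core.register_OperatorStats("ReLU", ReLUStats)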
@@ -49,14 +49,6 @@ public:
        );
    }

    size_t getNbArithmOps() const override {
        PYBIND11_OVERRIDE(
            size_t,
@@ -112,15 +104,20 @@ public:
    }
};

// See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions
class StaticAnalysis_Publicist : public StaticAnalysis {
public:
    using StaticAnalysis::getOpStats;
};

void init_StaticAnalysis(py::module& m){
    py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr())
        .def(py::init<const Operator&>(), py::arg("op"))
        .def("get_operator", &OperatorStats::getOperator)
        .def("get_nb_params", &OperatorStats::getNbParams)
        .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
        .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
        .def("get_params_size", &OperatorStats::getParamsSize)
        .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
        .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
        .def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
@@ -129,10 +126,20 @@ void init_StaticAnalysis(py::module& m){
        ;
    declare_registrable<OperatorStats>(m, "OperatorStats");

    py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr())
        .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
        .def("get_graph", &StaticAnalysis::getGraph)
        .def("get_nb_params", &StaticAnalysis::getNbParams)
        .def("get_nb_fixed_params", &StaticAnalysis::getNbFixedParams)
        .def("get_nb_trainable_params", &StaticAnalysis::getNbTrainableParams)
        .def("get_params_size", &StaticAnalysis::getParamsSize)
        .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
        .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
        .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
        .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps)
        .def("get_nb_flops", &StaticAnalysis::getNbFlops)
        .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
        .def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node"))
        ;
}
}
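With getOpStats made bindable through the publicist, per-node statistics become reachable from Python; a minimal sketch using only methods bound above (assumption: `model` is a GraphView):

    import aidge_core
    stats = aidge_core.StaticAnalysis(model)
    for node in model.get_ordered_nodes():
        op_stats = stats.get_op_stats(node)  # OperatorStats, MetaOpStats or a registered specialization
        print(node.type(), op_stats.get_nb_mac_ops())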
@@ -11,6 +11,17 @@
#include "aidge/data/Data.hpp"

bool Aidge::isDataTypeFloatingPoint(const DataType& type) {
    switch (type) {
    case DataType::Float64:
    case DataType::Float32:
    case DataType::Float16:
    case DataType::BFloat16: return true;
    default: return false;
    }
    return false;
}

size_t Aidge::getDataTypeBitWidth(const DataType& type) {
    switch (type) {
    case DataType::Float64: return 64;
...
@@ -53,6 +53,16 @@ size_t Aidge::OperatorStats::getParamsSize() const {
    return paramsSize;
}

size_t Aidge::OperatorStats::getNbFlops() const {
    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
    if (opTensor) {
        if (isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) {
            return getNbArithmOps();
        }
    }
    return 0;
}

Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
    : mGraph(graph)
{
@@ -89,9 +99,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
        }
    }

    const auto stats = getOpStats(node);
    nbTrainableParams += stats->getNbTrainableParams();
    nbFixedParams += stats->getNbFixedParams();
    paramsSize += stats->getParamsSize();
@@ -120,3 +128,21 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
    fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
    fmt::println("--------------------------------------------------------------------------------");
}

std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
    return (Registrar<OperatorStats>::exists(node->type()))
        ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
        : (node->getOperator()->isAtomic())
            ? std::make_shared<OperatorStats>(*(node->getOperator()))
            : std::make_shared<MetaOpStats>(*(node->getOperator()));
}

size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const {
    return std::accumulate(
        mGraph->getNodes().cbegin(),
        mGraph->getNodes().cend(),
        std::size_t(0),
        [this, func](const size_t& lhs, const std::shared_ptr<Node>& rhs) {
            return lhs + (this->getOpStats(rhs).get()->*func)();
        });
}
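End to end, the aggregate queries then work from both languages; a minimal Python sketch using the StaticAnalysis bindings above (assumption: `model` is a GraphView):

    import aidge_core
    stats = aidge_core.StaticAnalysis(model)
    print("Params:", stats.get_nb_params())
    print("MACs:  ", stats.get_nb_mac_ops())
    print("FLOPs: ", stats.get_nb_flops())  # 0 for non-floating-point outputs, per getNbFlops() above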