From aa939b56b45d8e01618bcd3045d45dcbf2eeea33 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 11 Oct 2024 17:28:30 +0200 Subject: [PATCH] Improved concept --- aidge_core/static_analysis.py | 63 ++++++++-- include/aidge/backend/cpu/data/TensorImpl.hpp | 35 ++---- include/aidge/data/Data.hpp | 1 + include/aidge/graph/StaticAnalysis.hpp | 109 ++++++++++++++---- .../graph/pybind_StaticAnalysis.cpp | 29 +++-- src/data/Data.cpp | 11 ++ src/graph/StaticAnalysis.cpp | 32 ++++- 7 files changed, 212 insertions(+), 68 deletions(-) diff --git a/aidge_core/static_analysis.py b/aidge_core/static_analysis.py index ea88af5aa..ddc494efc 100644 --- a/aidge_core/static_analysis.py +++ b/aidge_core/static_analysis.py @@ -1,24 +1,66 @@ import matplotlib.pyplot as plt +from functools import partial import aidge_core class StaticAnalysisExt(aidge_core.StaticAnalysis): - def log_nb_params(self, filename): + def log_nb_params(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_params, filename, title, log_scale) + + def log_nb_fixed_params(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_fixed_params, filename, title, log_scale) + + def log_nb_trainable_params(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_trainable_params, filename, title, log_scale) + + def log_params_size(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_params_size, filename, title, log_scale) + + def log_nb_arithm_ops(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale) + + def log_nb_logic_ops(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale) + + def log_nb_comp_ops(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale) + + def log_nb_mac_ops(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale) + + def log_nb_flops(self, filename, title=None, log_scale=False): + self._log_callback(aidge_core.OperatorStats.get_nb_flops, filename, title, log_scale) + + def _log_callback(self, callback, filename, title=None, log_scale=False): + """ + Log a statistic given by an OperatorStats callback member function. + Usage: + + stats = StaticAnalysisExt(model) + stats.log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator") + + :param func: OperatorStats member function to call. + :param filename: Output graph file name. + :type filename: str + :param title: Title of the graph. + :type title: str + """ + namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); nodes = self.get_graph().get_ordered_nodes() - names = [] - values = [] + series = [] for node in nodes: if node.type() == "Producer": continue - if node.type() in aidge_core.get_keys_OperatorStats(): - stats = aidge_core.get_key_value_OperatorStats(node.type()) - else: - stats = aidge_core.OperatorStats(node.get_operator()) - names.append(namePtrTable[node]) - values.append(stats.get_nb_params()) + stats = self.get_op_stats(node) + series.append([namePtrTable[node], partial(callback, stats)()]) + + if title is None: title = str(callback) + self._log_bar(series, filename, title, log_scale) + def _log_bar(self, series, filename, title=None, log_scale=False): + names, values = zip(*series) fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5)) plt.xlim(-0.5, len(names) - 0.5) plt.bar(names, values) @@ -27,5 +69,6 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis): plt.grid(axis='y', which='minor', linestyle=':', color='lightgray') plt.gca().set_axisbelow(True) plt.xticks(rotation='vertical') - plt.title('Number of params per operator') + if log_scale: plt.yscale('log') + if title is not None: plt.title(title) plt.savefig(filename, bbox_inches='tight') diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 6454ed233..fd2a0b3f4 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -119,30 +119,17 @@ private: template <typename T> const std::string TensorImpl_cpu<T>::Backend = "cpu"; -namespace { -static Registrar<Tensor> registrarTensorImpl_cpu_Float64( - {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_Float32( - {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_Float16( - {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_Int64( - {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_Int32( - {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_Int16( - {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_Int8( - {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_UInt64( - {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_UInt32( - {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_UInt16( - {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create); -static Registrar<Tensor> registrarTensorImpl_cpu_UInt8( - {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create); -} // namespace +REGISTRAR(Tensor, {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create); +REGISTRAR(Tensor, {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create); +REGISTRAR(Tensor, {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create); +REGISTRAR(Tensor, {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create); } // namespace Aidge #endif /* AIDGE_CPU_DATA_TENSORIMPL_H_ */ diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index df52b30f8..6f8771942 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -52,6 +52,7 @@ enum class DataType { Any }; +bool isDataTypeFloatingPoint(const DataType& type); size_t getDataTypeBitWidth(const DataType& type); enum class DataFormat { diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp index 091acd0aa..535d10857 100644 --- a/include/aidge/graph/StaticAnalysis.hpp +++ b/include/aidge/graph/StaticAnalysis.hpp @@ -22,19 +22,15 @@ #include "aidge/operator/Producer.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MetaOperator.hpp" namespace Aidge { -class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { -public: - StaticAnalysis(std::shared_ptr<GraphView> graph); - const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } - virtual void summary(bool incProducers = false) const; - virtual ~StaticAnalysis() = default; - -protected: - const std::shared_ptr<GraphView> mGraph; -}; - +/** + * @brief Base class to compute statistics from an Operator + * + */ class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> { public: OperatorStats(const Operator& op); @@ -43,32 +39,105 @@ public: virtual size_t getNbFixedParams() const { return 0; }; virtual size_t getNbTrainableParams() const; virtual size_t getParamsSize() const; - virtual size_t getNbMemAccess() const { return 0; }; - virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); }; + virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); }; virtual size_t getNbLogicOps() const { return 0; }; virtual size_t getNbCompOps() const { return 0; }; - virtual size_t getNbMACOps() const { return 0; }; - virtual size_t getNbFlops() const { return 0; }; + virtual size_t getNbMACOps() const { return 0; }; + virtual size_t getNbFlops() const; virtual ~OperatorStats() = default; protected: const Operator &mOp; }; +/** + * @brief Base class to compute statistics from a GraphView + * + */ +class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { +public: + StaticAnalysis(std::shared_ptr<GraphView> graph); + const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } + size_t getNbParams() const { return accumulate(&OperatorStats::getNbParams); } + size_t getNbFixedParams() const { return accumulate(&OperatorStats::getNbFixedParams); } + size_t getNbTrainableParams() const { return accumulate(&OperatorStats::getNbTrainableParams); } + size_t getParamsSize() const { return accumulate(&OperatorStats::getParamsSize); } + size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } + size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } + size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } + size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); } + size_t getNbFlops() const { return accumulate(&OperatorStats::getNbFlops); } + virtual void summary(bool incProducers = false) const; + virtual ~StaticAnalysis() = default; + +protected: + const std::shared_ptr<GraphView> mGraph; + + std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const; + size_t accumulate(size_t (OperatorStats::*func)() const) const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class MetaOpStats : public OperatorStats { +public: + MetaOpStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<MetaOpStats> create(const Operator& op) { + return std::make_unique<MetaOpStats>(op); + } + + size_t getNbFixedParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFixedParams(); } + size_t getNbTrainableParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbTrainableParams(); } + size_t getParamsSize() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getParamsSize(); } + size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } + size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } + size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } + size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); } + size_t getNbFlops() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFlops(); } +}; + +template <class OP> class ConvStats : public OperatorStats { public: ConvStats(const Operator& op) : OperatorStats(op) {} - static std::unique_ptr<ConvStats> create(const Operator& op) { - return std::make_unique<ConvStats>(op); + static std::unique_ptr<ConvStats<OP>> create(const Operator& op) { + return std::make_unique<ConvStats<OP>>(op); + } + + size_t getNbMACOps() const override { + const OP& op_ = dynamic_cast<const OP&>(mOp); + const std::size_t kernelSize = std::accumulate( + op_.kernelDims().cbegin(), + op_.kernelDims().cend(), + std::size_t(1), + std::multiplies<std::size_t>()); + // NCHW + const std::size_t outputSize = op_.getOutput(0)->dims()[2] * op_.getOutput(0)->dims()[3]; + return (kernelSize * outputSize); + } +}; + +// Beware: cannot use Conv_Op<2>::Type as key because static variable initialization order is undefined! +REGISTRAR(OperatorStats, "Conv", ConvStats<Conv_Op<2>>::create); +REGISTRAR(OperatorStats, "ConvDepthWise", ConvStats<ConvDepthWise_Op<2>>::create); + +class FCStats : public OperatorStats { +public: + FCStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<FCStats> create(const Operator& op) { + return std::make_unique<FCStats>(op); } - size_t getNbMACOps() const { - return 0; + size_t getNbMACOps() const override { + const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp); + return op_.getInput(1)->size(); } }; -REGISTRAR(OperatorStats, Conv_Op<2>::Type, ConvStats::create); +REGISTRAR(OperatorStats, "FC", FCStats::create); } #endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */ diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp index 75b552071..6df31f555 100644 --- a/python_binding/graph/pybind_StaticAnalysis.cpp +++ b/python_binding/graph/pybind_StaticAnalysis.cpp @@ -49,14 +49,6 @@ public: ); } - size_t getNbMemAccess() const override { - PYBIND11_OVERRIDE( - size_t, - OperatorStats, - getNbMemAccess - ); - } - size_t getNbArithmOps() const override { PYBIND11_OVERRIDE( size_t, @@ -112,15 +104,20 @@ public: } }; +// See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions +class StaticAnalysis_Publicist : public StaticAnalysis { +public: + using StaticAnalysis::getOpStats; +}; + void init_StaticAnalysis(py::module& m){ - py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::dynamic_attr()) + py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr()) .def(py::init<const Operator&>(), py::arg("op")) .def("get_operator", &OperatorStats::getOperator) .def("get_nb_params", &OperatorStats::getNbParams) .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams) .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams) .def("get_params_size", &OperatorStats::getParamsSize) - .def("get_nb_mem_access", &OperatorStats::getNbMemAccess) .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps) .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps) .def("get_nb_comp_ops", &OperatorStats::getNbCompOps) @@ -129,10 +126,20 @@ void init_StaticAnalysis(py::module& m){ ; declare_registrable<OperatorStats>(m, "OperatorStats"); - py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::dynamic_attr()) + py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr()) .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) .def("get_graph", &StaticAnalysis::getGraph) + .def("get_nb_params", &StaticAnalysis::getNbParams) + .def("get_nb_fixed_params", &StaticAnalysis::getNbFixedParams) + .def("get_nb_trainable_params", &StaticAnalysis::getNbTrainableParams) + .def("get_params_size", &StaticAnalysis::getParamsSize) + .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps) + .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps) + .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps) + .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps) + .def("get_nb_flops", &StaticAnalysis::getNbFlops) .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) + .def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node")) ; } } diff --git a/src/data/Data.cpp b/src/data/Data.cpp index 91c572897..865b4ff2d 100644 --- a/src/data/Data.cpp +++ b/src/data/Data.cpp @@ -11,6 +11,17 @@ #include "aidge/data/Data.hpp" +bool Aidge::isDataTypeFloatingPoint(const DataType& type) { + switch (type) { + case DataType::Float64: + case DataType::Float32: + case DataType::Float16: + case DataType::BFloat16: return true; + default: return false; + } + return false; +} + size_t Aidge::getDataTypeBitWidth(const DataType& type) { switch (type) { case DataType::Float64: return 64; diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp index fba419516..aa12e29c7 100644 --- a/src/graph/StaticAnalysis.cpp +++ b/src/graph/StaticAnalysis.cpp @@ -53,6 +53,16 @@ size_t Aidge::OperatorStats::getParamsSize() const { return paramsSize; } +size_t Aidge::OperatorStats::getNbFlops() const { + const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); + if (opTensor) { + if (isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) { + return getNbArithmOps(); + } + } + return 0; +} + Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph) : mGraph(graph) { @@ -89,9 +99,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { } } - const auto stats = (Registrar<OperatorStats>::exists(node->type())) - ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) - : std::make_shared<OperatorStats>(*(node->getOperator())); + const auto stats = getOpStats(node); nbTrainableParams += stats->getNbTrainableParams(); nbFixedParams += stats->getNbFixedParams(); paramsSize += stats->getParamsSize(); @@ -120,3 +128,21 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024); fmt::println("--------------------------------------------------------------------------------"); } + +std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const { + return (Registrar<OperatorStats>::exists(node->type())) + ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) + : (node->getOperator()->isAtomic()) + ? std::make_shared<OperatorStats>(*(node->getOperator())) + : std::make_shared<MetaOpStats>(*(node->getOperator())); +} + +size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const { + return std::accumulate( + mGraph->getNodes().cbegin(), + mGraph->getNodes().cend(), + std::size_t(0), + [this, func](const size_t& lhs, const std::shared_ptr<Node>& rhs) { + return lhs + (this->getOpStats(rhs).get()->*func)(); + }); +} -- GitLab