From aa939b56b45d8e01618bcd3045d45dcbf2eeea33 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Fri, 11 Oct 2024 17:28:30 +0200
Subject: [PATCH] Improved static analysis concept

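Rework the static analysis concept: per-operator statistics live in
OperatorStats (params, arithmetic/logic/comparison ops, MACs, flops,
where flops default to the arithmetic ops of floating-point operators),
and StaticAnalysis now accumulates each statistic over the whole
GraphView. ConvStats is templated so it also covers ConvDepthWise, the
new MetaOpStats recurses into a meta-operator's micro-graph, and the
Python StaticAnalysisExt helper gains one bar-chart logging method per
statistic, e.g. (with model an aidge_core.GraphView):

    stats = aidge_core.static_analysis.StaticAnalysisExt(model)
    stats.log_nb_mac_ops("mac_ops.png", "Nb MAC ops per operator", log_scale=True)

Also replace the verbose CPU TensorImpl registrars with the REGISTRAR
macro and remove the never-implemented getNbMemAccess().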
---
 aidge_core/static_analysis.py                 |  63 ++++++++--
 include/aidge/backend/cpu/data/TensorImpl.hpp |  35 ++----
 include/aidge/data/Data.hpp                   |   1 +
 include/aidge/graph/StaticAnalysis.hpp        | 109 ++++++++++++++----
 .../graph/pybind_StaticAnalysis.cpp           |  29 +++--
 src/data/Data.cpp                             |  11 ++
 src/graph/StaticAnalysis.cpp                  |  32 ++++-
 7 files changed, 212 insertions(+), 68 deletions(-)

diff --git a/aidge_core/static_analysis.py b/aidge_core/static_analysis.py
index ea88af5aa..ddc494efc 100644
--- a/aidge_core/static_analysis.py
+++ b/aidge_core/static_analysis.py
@@ -1,24 +1,66 @@
 import matplotlib.pyplot as plt
 import aidge_core
 
 class StaticAnalysisExt(aidge_core.StaticAnalysis):
-    def log_nb_params(self, filename):
+    def log_nb_params(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_params, filename, title, log_scale)
+
+    def log_nb_fixed_params(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_fixed_params, filename, title, log_scale)
+
+    def log_nb_trainable_params(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_trainable_params, filename, title, log_scale)
+
+    def log_params_size(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_params_size, filename, title, log_scale)
+
+    def log_nb_arithm_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale)
+
+    def log_nb_logic_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_logic_ops, filename, title, log_scale)
+
+    def log_nb_comp_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale)
+
+    def log_nb_mac_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)
+
+    def log_nb_flops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_flops, filename, title, log_scale)
+
+    def _log_callback(self, callback, filename, title=None, log_scale=False):
+        """
+        Log a statistic computed by an OperatorStats member function, as a
+        bar chart with one bar per node of the graph (Producers excluded).
+        Usage:
+
+            stats = StaticAnalysisExt(model)
+            stats._log_callback(aidge_core.OperatorStats.get_nb_params, "stats.png", "Nb params per operator")
+
+        :param callback: OperatorStats member function to call.
+        :param filename: Output graph file name.
+        :type filename: str
+        :param title: Title of the graph. Defaults to the callback name.
+        :type title: str
+        :param log_scale: If True, use a logarithmic scale for the y-axis.
+        :type log_scale: bool
+        """
+
         namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
         nodes = self.get_graph().get_ordered_nodes()
-        names = []
-        values = []
+        series = []
 
         for node in nodes:
             if node.type() == "Producer":
                 continue
 
-            if node.type() in aidge_core.get_keys_OperatorStats():
-                stats = aidge_core.get_key_value_OperatorStats(node.type())
-            else:
-                stats = aidge_core.OperatorStats(node.get_operator())
-            names.append(namePtrTable[node])
-            values.append(stats.get_nb_params())
+            stats = self.get_op_stats(node)
+            series.append([namePtrTable[node], callback(stats)])
+
+        if title is None: title = callback.__name__
+        self._log_bar(series, filename, title, log_scale)
 
+    def _log_bar(self, series, filename, title=None, log_scale=False):
+        names, values = zip(*series)
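+        # Grow the figure width with the number of bars to keep the labels readable.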
         fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5))
         plt.xlim(-0.5, len(names) - 0.5)
         plt.bar(names, values)
@@ -27,5 +69,6 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
         plt.grid(axis='y', which='minor', linestyle=':', color='lightgray')
         plt.gca().set_axisbelow(True)
         plt.xticks(rotation='vertical')
-        plt.title('Number of params per operator')
+        if log_scale: plt.yscale('log')
+        if title is not None: plt.title(title)
         plt.savefig(filename, bbox_inches='tight')
diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index 6454ed233..fd2a0b3f4 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -119,30 +119,17 @@ private:
 template <typename T>
 const std::string TensorImpl_cpu<T>::Backend = "cpu";
 
-namespace {
-static Registrar<Tensor> registrarTensorImpl_cpu_Float64(
-        {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Float32(
-        {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
-        {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int64(
-        {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
-        {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int16(
-        {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int8(
-        {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt64(
-        {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt32(
-        {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt16(
-        {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt8(
-        {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create);
-}  // namespace
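+// Register the CPU tensor implementation for every supported data type.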
+REGISTRAR(Tensor, {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create);
 }  // namespace Aidge
 
 #endif /* AIDGE_CPU_DATA_TENSORIMPL_H_ */
diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index df52b30f8..6f8771942 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -52,6 +52,7 @@ enum class DataType {
     Any
 };
 
+bool isDataTypeFloatingPoint(const DataType& type);
 size_t getDataTypeBitWidth(const DataType& type);
 
 enum class DataFormat {
diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp
index 091acd0aa..535d10857 100644
--- a/include/aidge/graph/StaticAnalysis.hpp
+++ b/include/aidge/graph/StaticAnalysis.hpp
@@ -22,19 +22,15 @@
 
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/Conv.hpp"
+#include "aidge/operator/ConvDepthWise.hpp"
+#include "aidge/operator/FC.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 
 namespace Aidge {
-class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
-public:
-    StaticAnalysis(std::shared_ptr<GraphView> graph);
-    const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
-    virtual void summary(bool incProducers = false) const;
-    virtual ~StaticAnalysis() = default;
-
-protected:
-    const std::shared_ptr<GraphView> mGraph;
-};
-
+/**
+ * @brief Base class to compute statistics from an Operator.
+ */
 class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
 public:
     OperatorStats(const Operator& op);
@@ -43,32 +39,105 @@ public:
     virtual size_t getNbFixedParams() const { return 0; };
     virtual size_t getNbTrainableParams() const;
     virtual size_t getParamsSize() const;
-    virtual size_t getNbMemAccess() const  { return 0; };
-    virtual size_t getNbArithmOps() const  { return 2 * getNbMACOps(); };
+    virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
     virtual size_t getNbLogicOps() const { return 0; };
     virtual size_t getNbCompOps() const { return 0; };
-    virtual size_t getNbMACOps() const  { return 0; };
-    virtual size_t getNbFlops() const  { return 0; };
+    virtual size_t getNbMACOps() const { return 0; };
+    virtual size_t getNbFlops() const;
     virtual ~OperatorStats() = default;
 
 protected:
     const Operator &mOp;
 };
 
+/**
+ * @brief Base class to compute statistics from a GraphView.
+ */
+class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
+public:
+    StaticAnalysis(std::shared_ptr<GraphView> graph);
+    const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
+    size_t getNbParams() const { return accumulate(&OperatorStats::getNbParams); }
+    size_t getNbFixedParams() const { return accumulate(&OperatorStats::getNbFixedParams); }
+    size_t getNbTrainableParams() const { return accumulate(&OperatorStats::getNbTrainableParams); }
+    size_t getParamsSize() const { return accumulate(&OperatorStats::getParamsSize); }
+    size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
+    size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
+    size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
+    size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); }
+    size_t getNbFlops() const { return accumulate(&OperatorStats::getNbFlops); }
+    virtual void summary(bool incProducers = false) const;
+    virtual ~StaticAnalysis() = default;
+
+protected:
+    const std::shared_ptr<GraphView> mGraph;
+
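+    /**
+     * @brief Get the OperatorStats for a node: the implementation registered
+     * for its type if any, otherwise a generic fallback (MetaOpStats for
+     * non-atomic operators).
+     */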
+    std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const;
+    size_t accumulate(size_t (OperatorStats::*func)() const) const;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
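+/**
+ * @brief Statistics of a meta-operator: every statistic is delegated to a
+ * StaticAnalysis of its internal micro-graph.
+ */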
+class MetaOpStats : public OperatorStats {
+public:
+    MetaOpStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<MetaOpStats> create(const Operator& op) {
+        return std::make_unique<MetaOpStats>(op);
+    }
+
+    size_t getNbFixedParams() const override { return microGraphStats().getNbFixedParams(); }
+    size_t getNbTrainableParams() const override { return microGraphStats().getNbTrainableParams(); }
+    size_t getParamsSize() const override { return microGraphStats().getParamsSize(); }
+    size_t getNbArithmOps() const override { return microGraphStats().getNbArithmOps(); }
+    size_t getNbLogicOps() const override { return microGraphStats().getNbLogicOps(); }
+    size_t getNbCompOps() const override { return microGraphStats().getNbCompOps(); }
+    size_t getNbMACOps() const override { return microGraphStats().getNbMACOps(); }
+    size_t getNbFlops() const override { return microGraphStats().getNbFlops(); }
+
+private:
+    StaticAnalysis microGraphStats() const {
+        return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph());
+    }
+};
+
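+/**
+ * @brief Statistics of convolution-like operators (Conv, ConvDepthWise).
+ */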
+template <class OP>
 class ConvStats : public OperatorStats {
 public:
     ConvStats(const Operator& op) : OperatorStats(op) {}
 
-    static std::unique_ptr<ConvStats> create(const Operator& op) {
-        return std::make_unique<ConvStats>(op);
+    static std::unique_ptr<ConvStats<OP>> create(const Operator& op) {
+        return std::make_unique<ConvStats<OP>>(op);
+    }
+
+    size_t getNbMACOps() const override {
+        const OP& op_ = dynamic_cast<const OP&>(mOp);
+        // One MAC per weight (kernel window, input and output channels
+        // included) and per output spatial position.
+        const std::size_t weightsSize = op_.getInput(1)->size();
+        // NCHW: the spatial dimensions start at index 2.
+        const std::size_t outputSize = std::accumulate(
+            op_.getOutput(0)->dims().cbegin() + 2,
+            op_.getOutput(0)->dims().cend(),
+            std::size_t(1),
+            std::multiplies<std::size_t>());
+        return weightsSize * outputSize;
+    }
+};
+
+// Beware: cannot use Conv_Op<2>::Type as key because static variable initialization order is undefined!
+REGISTRAR(OperatorStats, "Conv", ConvStats<Conv_Op<2>>::create);
+REGISTRAR(OperatorStats, "ConvDepthWise", ConvStats<ConvDepthWise_Op<2>>::create);
+
+class FCStats : public OperatorStats {
+public:
+    FCStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<FCStats> create(const Operator& op) {
+        return std::make_unique<FCStats>(op);
     }
 
-    size_t getNbMACOps() const  {
-        return 0;
+    size_t getNbMACOps() const override {
+        const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
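+        // One MAC per weight: input #1 is the weight tensor.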
+        return op_.getInput(1)->size();
     }
 };
 
-REGISTRAR(OperatorStats, Conv_Op<2>::Type, ConvStats::create);
+REGISTRAR(OperatorStats, "FC", FCStats::create);
 }
 
 #endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp
index 75b552071..6df31f555 100644
--- a/python_binding/graph/pybind_StaticAnalysis.cpp
+++ b/python_binding/graph/pybind_StaticAnalysis.cpp
@@ -49,14 +49,6 @@ public:
         );
     }
 
-    size_t getNbMemAccess() const override {
-        PYBIND11_OVERRIDE(
-            size_t,
-            OperatorStats,
-            getNbMemAccess
-        );
-    }
-
     size_t getNbArithmOps() const override {
         PYBIND11_OVERRIDE(
             size_t,
@@ -112,15 +104,20 @@ public:
     }
 };
 
+// See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions
+class StaticAnalysis_Publicist : public StaticAnalysis {
+public:
+    using StaticAnalysis::getOpStats;
+};
+
 void init_StaticAnalysis(py::module& m){
-    py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::dynamic_attr())
+    py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr())
     .def(py::init<const Operator&>(), py::arg("op"))
     .def("get_operator", &OperatorStats::getOperator)
     .def("get_nb_params", &OperatorStats::getNbParams)
     .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
     .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
     .def("get_params_size", &OperatorStats::getParamsSize)
-    .def("get_nb_mem_access", &OperatorStats::getNbMemAccess)
     .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
     .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
     .def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
@@ -129,10 +126,20 @@ void init_StaticAnalysis(py::module& m){
     ;
     declare_registrable<OperatorStats>(m, "OperatorStats");
 
-    py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::dynamic_attr())
+    py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr())
     .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
     .def("get_graph", &StaticAnalysis::getGraph)
+    .def("get_nb_params", &StaticAnalysis::getNbParams)
+    .def("get_nb_fixed_params", &StaticAnalysis::getNbFixedParams)
+    .def("get_nb_trainable_params", &StaticAnalysis::getNbTrainableParams)
+    .def("get_params_size", &StaticAnalysis::getParamsSize)
+    .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
+    .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
+    .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
+    .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps)
+    .def("get_nb_flops", &StaticAnalysis::getNbFlops)
     .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
+    .def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node"))
     ;
 }
 }
diff --git a/src/data/Data.cpp b/src/data/Data.cpp
index 91c572897..865b4ff2d 100644
--- a/src/data/Data.cpp
+++ b/src/data/Data.cpp
@@ -11,6 +11,17 @@
 
 #include "aidge/data/Data.hpp"
 
+bool Aidge::isDataTypeFloatingPoint(const DataType& type) {
+    switch (type) {
+    case DataType::Float64:
+    case DataType::Float32:
+    case DataType::Float16:
+    case DataType::BFloat16:  return true;
+    default:                  return false;
+    }
+}
+
 size_t Aidge::getDataTypeBitWidth(const DataType& type) {
     switch (type) {
     case DataType::Float64:   return 64;
diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp
index fba419516..aa12e29c7 100644
--- a/src/graph/StaticAnalysis.cpp
+++ b/src/graph/StaticAnalysis.cpp
@@ -53,6 +53,16 @@ size_t Aidge::OperatorStats::getParamsSize() const {
     return paramsSize;
 }
 
+size_t Aidge::OperatorStats::getNbFlops() const {
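+    // Only floating-point arithmetic counts as flops; integer-only operators report 0.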
+    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
+    if (opTensor) {
+        if (isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) {
+            return getNbArithmOps();
+        }
+    }
+    return 0;
+}
+
 Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
   : mGraph(graph)
 {
@@ -89,9 +99,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
             }
         }
 
-        const auto stats = (Registrar<OperatorStats>::exists(node->type()))
-            ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
-            : std::make_shared<OperatorStats>(*(node->getOperator()));
+        const auto stats = getOpStats(node);
         nbTrainableParams += stats->getNbTrainableParams();
         nbFixedParams += stats->getNbFixedParams();
         paramsSize += stats->getParamsSize();
@@ -120,3 +128,21 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
     fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
     fmt::println("--------------------------------------------------------------------------------");
 }
+
+std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
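+    // Prefer the stats implementation registered for the node type; otherwise,
+    // fall back to the generic OperatorStats for atomic operators and to
+    // MetaOpStats for meta-operators.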
+    return (Registrar<OperatorStats>::exists(node->type()))
+        ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
+        : (node->getOperator()->isAtomic())
+            ? std::make_shared<OperatorStats>(*(node->getOperator()))
+            : std::make_shared<MetaOpStats>(*(node->getOperator()));
+}
+
+size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const {
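+    // Sum a per-operator statistic over every node of the graph.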
+    return std::accumulate(
+        mGraph->getNodes().cbegin(),
+        mGraph->getNodes().cend(),
+        std::size_t(0),
+        [this, func](std::size_t acc, const std::shared_ptr<Node>& node) {
+            return acc + (this->getOpStats(node).get()->*func)();
+        });
+}
-- 
GitLab