From 6f272b8576dbfcee2130e9fe2aeafd399b81a458 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 22 Oct 2024 17:05:18 +0200
Subject: [PATCH] Improved type of ops and supported ops

---
 aidge_core/static_analysis.py                 |  70 ++++++-
 include/aidge/graph/StaticAnalysis.hpp        | 179 ++++++++++++++++--
 .../graph/pybind_StaticAnalysis.cpp           |  26 ++-
 src/graph/StaticAnalysis.cpp                  |  10 +
 4 files changed, 252 insertions(+), 33 deletions(-)

diff --git a/aidge_core/static_analysis.py b/aidge_core/static_analysis.py
index cb288d06a..91c1c30a2 100644
--- a/aidge_core/static_analysis.py
+++ b/aidge_core/static_analysis.py
@@ -26,17 +26,27 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
     def log_nb_comp_ops(self, filename, title=None, log_scale=False):
         self._log_callback(aidge_core.OperatorStats.get_nb_comp_ops, filename, title, log_scale)
 
-    def log_nb_flops(self, filename, title=None, log_scale=False):
-        self._log_callback(aidge_core.OperatorStats.get_nb_flops, filename, title, log_scale)
+    def log_nb_nl_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_nl_ops, filename, title, log_scale)
+
+    def log_nb_mac_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)
 
     def log_nb_ops(self, filename, title=None, log_scale=False):
-        self._log_callback([aidge_core.OperatorStats.get_nb_arithm_ops,
+        self._log_callback(aidge_core.OperatorStats.get_nb_ops, filename, title, log_scale)
+
+    def log_nb_arithm_int_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_arithm_int_ops, filename, title, log_scale)
+
+    def log_nb_arithm_fp_ops(self, filename, title=None, log_scale=False):
+        self._log_callback(aidge_core.OperatorStats.get_nb_arithm_fp_ops, filename, title, log_scale)
+
+    def log_nb_ops_by_type(self, filename, title=None, log_scale=False):
+        self._log_callback([aidge_core.OperatorStats.get_nb_arithm_int_ops,
+                            aidge_core.OperatorStats.get_nb_arithm_fp_ops,
                             aidge_core.OperatorStats.get_nb_logic_ops,
                             aidge_core.OperatorStats.get_nb_comp_ops,
-                            aidge_core.OperatorStats.get_nb_flops], filename, title, log_scale)
-
-    def log_nb_mac_ops(self, filename, title=None, log_scale=False):
-        self._log_callback(aidge_core.OperatorStats.get_nb_mac_ops, filename, title, log_scale)
+                            aidge_core.OperatorStats.get_nb_nl_ops], filename, title, log_scale)
 
     def _log_callback(self, callback, filename, title=None, log_scale=False):
         """
@@ -56,6 +66,7 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
         namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
         nodes = self.get_graph().get_ordered_nodes()
         series = []
+        legend = None
 
         for node in nodes:
             if node.type() == "Producer":
@@ -77,14 +88,16 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
                 name = (name, attr)
             if isinstance(callback, list):
                 series.append([name, [partial(cb, stats)() for cb in callback]])
-                if title is None: title = str([cb.__name__ for cb in callback])
+                legend = [cb.__name__ for cb in callback]
+                if title is None: title = str(legend)
             else:
                 series.append([name, partial(callback, stats)()])
                 if title is None: title = callback.__name__
 
-        self._log_bar(series, filename, title, log_scale)
+        if title is None: title = str(callback)
+        self._log_bar(series, filename, title, legend, log_scale)
 
-    def _log_bar(self, series, filename, title=None, log_scale=False):
+    def _log_bar(self, series, filename, title=None, legend=None, log_scale=False):
         names, values = zip(*series)
         names_only = [item[0] if isinstance(item, tuple) else item for item in names]
         fig, ax = plt.subplots(figsize=(max(5, len(names)/4), 5))
@@ -116,4 +129,41 @@ class StaticAnalysisExt(aidge_core.StaticAnalysis):
         plt.xticks(rotation='vertical')
         if log_scale: plt.yscale('log')
         if title is not None: plt.title(title)
+        if legend is not None: plt.legend(legend)
+        plt.savefig(filename, bbox_inches='tight')
+
+    def _log_barh(self, series, filename, title=None, legend=None, log_scale=False):
+        names, values = zip(*series)
+        names_only = [item[0] if isinstance(item, tuple) else item for item in names]
+        fig, ax = plt.subplots(figsize=(10, max(5, len(names)/4)))
+        plt.ylim(-0.5, len(names) - 0.5)
+        if isinstance(values[0], list):
+            series = [list(i) for i in zip(*values)]
+            left = np.zeros(len(series[0]))
+            for i, serie in enumerate(series):
+                plt.barh(names_only, serie, left=left)
+                left += serie
+        else:
+            plt.barh(names_only, values)
+        ax.xaxis.minorticks_on()
+        plt.grid(axis='x', which='major', linestyle='--', color='gray')
+        plt.grid(axis='x', which='minor', linestyle=':', color='lightgray')
+        formatter0 = matplotlib.ticker.EngFormatter(unit='')
+        ax.xaxis.set_major_formatter(formatter0)
+        plt.gca().set_axisbelow(True)
+        plt.gca().xaxis.set_label_position('top')
+        plt.gca().xaxis.tick_top()
+
+        labels = plt.gca().get_yticks()
+        tick_labels = plt.gca().get_yticklabels()
+        for i, label in enumerate(labels):
+            if isinstance(names[i], tuple):
+                if 'color' in names[i][1]:
+                    tick_labels[i].set_color(names[i][1]['color'])
+                elif 'fontweight' in names[i][1]:
+                    tick_labels[i].set_fontweight(names[i][1]['fontweight'])
+
+        if log_scale: plt.xscale('log')
+        if title is not None: plt.title(title)
+        if legend is not None: plt.legend(legend)
         plt.savefig(filename, bbox_inches='tight')
diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp
index ed2930389..a4b49bf08 100644
--- a/include/aidge/graph/StaticAnalysis.hpp
+++ b/include/aidge/graph/StaticAnalysis.hpp
@@ -26,6 +26,9 @@
 #include "aidge/operator/FC.hpp"
 #include "aidge/operator/MatMul.hpp"
 #include "aidge/operator/ReLU.hpp"
+#include "aidge/operator/ReduceMean.hpp"
+#include "aidge/operator/ReduceSum.hpp"
+#include "aidge/operator/Softmax.hpp"
 #include "aidge/operator/MetaOperator.hpp"
 
 namespace Aidge {
@@ -41,11 +44,73 @@ public:
     virtual size_t getNbFixedParams() const { return 0; };
     virtual size_t getNbTrainableParams() const;
     virtual size_t getParamsSize() const;
+
+    /**
+     * @brief Get the total number of arithmetic operations for the operator.
+     * This includes base arithmetic operations: +, -, / and *.
+     * Example of Operator with only comparison operatons: Conv.
+     * 
+     * @return size_t Number of comparison operations.
+     */
     virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
+
+    /**
+     * @brief Get the total number of logic operations for the operator.
+     * This includes operations like logical shift, or, and...
+     * Example of Operator with only comparison operatons: BitShift.
+     * 
+     * @return size_t Number of comparison operations.
+     */
     virtual size_t getNbLogicOps() const { return 0; };
+
+    /**
+     * @brief Get the total number of comparison operations for the operator.
+     * This includes operations like <, >, =...
+     * Example of Operator with only comparison operatons: MaxPool.
+     * 
+     * @return size_t Number of comparison operations.
+     */
     virtual size_t getNbCompOps() const { return 0; };
-    virtual size_t getNbFlops() const { return 0; };
-    size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbFlops(); };
+
+    /**
+     * @brief Get the total number of non-linear (NL) operations for the operator.
+     * This includes operations like calls to tanh(), erf(), cos()...
+     * Example of Operator with only NL operatons: Tanh.
+     * Non-linear operations are necessarily of floating-point type.
+     * 
+     * @return size_t Number of non-linear (NL) operations.
+     */
+    virtual size_t getNbNLOps() const { return 0; };
+
+    /**
+     * @brief Get the total number of operations for the operator.
+     * Total number of operations = arithmetic ops + logic ops + comp ops + NL ops.
+     * 
+     * @return size_t Number of operations.
+     */
+    size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbNLOps(); };
+
+    /**
+     * @brief Get the total number of INT arithmetic operations for the operator.
+     * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
+     * 
+     * @return size_t Number of INT arithmetic operations.
+     */
+    virtual size_t getNbArithmIntOps() const;
+
+    /**
+     * @brief Get the total number of FP arithmetic operations for the operator.
+     * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
+     * 
+     * @return size_t Number of FP arithmetic operations.
+     */
+    size_t getNbArithmFpOps() const { return getNbArithmOps() - getNbArithmIntOps(); };
+
+    /**
+     * @brief Get the total number of MAC operations for the operator.
+     * 
+     * @return size_t Number of MAC operations.
+     */
     virtual size_t getNbMACOps() const { return 0; };
     virtual ~OperatorStats() = default;
 
@@ -68,7 +133,10 @@ public:
     size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
     size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
     size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
-    size_t getNbFlops() const { return accumulate(&OperatorStats::getNbFlops); }
+    size_t getNbNLOps() const { return accumulate(&OperatorStats::getNbNLOps); }
+    size_t getNbOps() const { return accumulate(&OperatorStats::getNbOps); }
+    size_t getNbArithmIntOps() const { return accumulate(&OperatorStats::getNbArithmIntOps); }
+    size_t getNbArithmFpOps() const { return accumulate(&OperatorStats::getNbArithmFpOps); }
     size_t getNbMACOps() const { return accumulate(&OperatorStats::getNbMACOps); }
     virtual void summary(bool incProducers = false) const;
     virtual ~StaticAnalysis() = default;
@@ -96,7 +164,8 @@ public:
     size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); }
     size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); }
     size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); }
-    size_t getNbFlops() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFlops(); }
+    size_t getNbNLOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbNLOps(); }
+    size_t getNbArithmIntOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmIntOps(); }
     size_t getNbMACOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbMACOps(); }
 };
 
@@ -112,8 +181,13 @@ public:
     size_t getNbMACOps() const override {
         const OP& op_ = dynamic_cast<const OP&>(mOp);
 	    const std::size_t weightsSize = op_.getInput(1)->size();
-        const std::size_t outputSize = op_.getOutput(0)->dims()[2] * op_.getOutput(0)->dims()[3]; // NCHW
-        return (weightsSize * outputSize);
+        const std::size_t outputSize
+            = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
+                              op_.getOutput(0)->dims().cend(),
+                              1,
+                              std::multiplies<size_t>()); // NCHW...
+        const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
+        return batchSize * (weightsSize * outputSize);
     }
 };
 
@@ -132,7 +206,8 @@ public:
     size_t getNbMACOps() const override {
         const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
 	    const std::size_t weightsSize = op_.getInput(1)->size();
-        return weightsSize;
+        const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
+        return batchSize * weightsSize;
     }
 };
 
@@ -182,6 +257,75 @@ public:
 
 REGISTRAR(OperatorStats, "ReLU", ReLUStats::create);
 
+class ReduceMeanStats : public OperatorStats {
+public:
+    ReduceMeanStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<ReduceMeanStats> create(const Operator& op) {
+        return std::make_unique<ReduceMeanStats>(op);
+    }
+
+    size_t getNbArithmOps() const override {
+        const ReduceMean_Op& op_ = dynamic_cast<const ReduceMean_Op&>(mOp);
+        const size_t nbIn = op_.getInput(0)->size();
+        const size_t nbOut = op_.getOutput(0)->size();
+        const size_t nbReduce = nbIn / nbOut;
+        // (nbReduce - 1) additions + 1 division for each output
+        return nbOut * nbReduce;
+    }
+};
+
+REGISTRAR(OperatorStats, "ReduceMean", ReduceMeanStats::create);
+
+class ReduceSumStats : public OperatorStats {
+public:
+    ReduceSumStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<ReduceSumStats> create(const Operator& op) {
+        return std::make_unique<ReduceSumStats>(op);
+    }
+
+    size_t getNbArithmOps() const override {
+        const ReduceSum_Op& op_ = dynamic_cast<const ReduceSum_Op&>(mOp);
+        const size_t nbIn = op_.getInput(0)->size();
+        const size_t nbOut = op_.getOutput(0)->size();
+        const size_t nbReduce = nbIn / nbOut;
+        // (nbReduce - 1) additions for each output
+        return nbOut * (nbReduce - 1);
+    }
+};
+
+REGISTRAR(OperatorStats, "ReduceSum", ReduceSumStats::create);
+
+class SoftmaxStats : public OperatorStats {
+public:
+    SoftmaxStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<SoftmaxStats> create(const Operator& op) {
+        return std::make_unique<SoftmaxStats>(op);
+    }
+
+    size_t getNbArithmOps() const override {
+        const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
+        const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
+        const size_t nbReduce = op_.getInput(0)->dims()[axis];
+        const size_t nbOut = op_.getOutput(0)->size();
+        // nbOut divisions + (nbReduce - 1) additions
+        return nbOut + (nbReduce - 1);
+    }
+
+    size_t getNbNLOps() const override {
+        const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
+        const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
+        const size_t nbReduce = op_.getInput(0)->dims()[axis];
+        const size_t nbOut = op_.getOutput(0)->size();
+        // nbOut exp + nbReduce exp
+        return nbOut + nbReduce;
+    }
+};
+
+REGISTRAR(OperatorStats, "Softmax", SoftmaxStats::create);
+
 class MemOpStats : public OperatorStats {
 public:
     MemOpStats(const Operator& op) : OperatorStats(op) {}
@@ -220,25 +364,26 @@ REGISTRAR(OperatorStats, "Sub", ElemWiseOpStats::create);
 REGISTRAR(OperatorStats, "Mul", ElemWiseOpStats::create);
 REGISTRAR(OperatorStats, "Div", ElemWiseOpStats::create);
 
-class ElemWiseFlopStats : public OperatorStats {
+class ElemWiseNLOpStats : public OperatorStats {
 public:
-    ElemWiseFlopStats(const Operator& op) : OperatorStats(op) {}
+    ElemWiseNLOpStats(const Operator& op) : OperatorStats(op) {}
 
-    static std::unique_ptr<ElemWiseFlopStats> create(const Operator& op) {
-        return std::make_unique<ElemWiseFlopStats>(op);
+    static std::unique_ptr<ElemWiseNLOpStats> create(const Operator& op) {
+        return std::make_unique<ElemWiseNLOpStats>(op);
     }
 
-    size_t getNbFlops() const override {
+    size_t getNbNLOps() const override {
         const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
         return op_.getOutput(0)->size();
     }
 };
 
-REGISTRAR(OperatorStats, "Sqrt", ElemWiseFlopStats::create);
-REGISTRAR(OperatorStats, "Erf", ElemWiseFlopStats::create);
-REGISTRAR(OperatorStats, "Ln", ElemWiseFlopStats::create);
-REGISTRAR(OperatorStats, "Sigmoid", ElemWiseFlopStats::create);
-REGISTRAR(OperatorStats, "Tanh", ElemWiseFlopStats::create);
+REGISTRAR(OperatorStats, "Sqrt", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Erf", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Ln", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Sigmoid", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Tanh", ElemWiseNLOpStats::create);
+REGISTRAR(OperatorStats, "Pow", ElemWiseNLOpStats::create);
 }
 
 #endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp
index 6df31f555..e1c3744df 100644
--- a/python_binding/graph/pybind_StaticAnalysis.cpp
+++ b/python_binding/graph/pybind_StaticAnalysis.cpp
@@ -73,19 +73,27 @@ public:
         );
     }
 
-    size_t getNbMACOps() const override {
+    size_t getNbNLOps() const override {
         PYBIND11_OVERRIDE(
             size_t,
             OperatorStats,
-            getNbMACOps
+            getNbNLOps
+        );
+    }
+
+    size_t getNbArithmIntOps() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbArithmIntOps
         );
     }
 
-    size_t getNbFlops() const override {
+    size_t getNbMACOps() const override {
         PYBIND11_OVERRIDE(
             size_t,
             OperatorStats,
-            getNbFlops
+            getNbMACOps
         );
     }
 };
@@ -121,8 +129,11 @@ void init_StaticAnalysis(py::module& m){
     .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
     .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
     .def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
+    .def("get_nb_nl_ops", &OperatorStats::getNbNLOps)
+    .def("get_nb_ops", &OperatorStats::getNbOps)
+    .def("get_nb_arithm_int_ops", &OperatorStats::getNbArithmIntOps)
+    .def("get_nb_arithm_fp_ops", &OperatorStats::getNbArithmFpOps)
     .def("get_nb_mac_ops", &OperatorStats::getNbMACOps)
-    .def("get_nb_flops", &OperatorStats::getNbFlops)
     ;
     declare_registrable<OperatorStats>(m, "OperatorStats");
 
@@ -136,8 +147,11 @@ void init_StaticAnalysis(py::module& m){
     .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
     .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
     .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
+    .def("get_nb_nl_ops", &StaticAnalysis::getNbNLOps)
+    .def("get_nb_ops", &StaticAnalysis::getNbOps)
+    .def("get_nb_arithm_int_ops", &StaticAnalysis::getNbArithmIntOps)
+    .def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps)
     .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps)
-    .def("get_nb_flops", &StaticAnalysis::getNbFlops)
     .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
     .def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node"))
     ;
diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp
index 6e7cb3bed..ddec32505 100644
--- a/src/graph/StaticAnalysis.cpp
+++ b/src/graph/StaticAnalysis.cpp
@@ -53,6 +53,16 @@ size_t Aidge::OperatorStats::getParamsSize() const {
     return paramsSize;
 }
 
+size_t Aidge::OperatorStats::getNbArithmIntOps() const {
+    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
+    if (opTensor) {
+        if (!isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) {
+            return getNbArithmOps();
+        }
+    }
+    return 0;
+}
+
 Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
   : mGraph(graph)
 {
-- 
GitLab