From 4daa23a0cb2fac4060692461e2bd5391f64d486b Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 9 Oct 2024 14:58:36 +0200
Subject: [PATCH] Initial version of hybrid C++/Python static analysis

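Adds a StaticAnalysis class that walks a GraphView and prints a
torchsummary-like report (layer, output shape, parameter count, memory
estimates), together with a registrable OperatorStats hierarchy so that
per-operator statistics (params, MACs, FLOPs, memory accesses) can be
specialized per operator type, either in C++ (see ConvStats) or from
Python through the pybind11 trampoline classes.

Minimal usage sketch from Python (the graph-building calls are only an
example; any GraphView whose dims have been forwarded works):

    import aidge_core

    model = aidge_core.sequential([
        aidge_core.Conv2D(3, 32, [3, 3], name="conv0"),
        aidge_core.ReLU(name="relu0")
    ])
    model.forward_dims(dims=[[1, 3, 224, 224]])

    stats = aidge_core.StaticAnalysis(model)
    stats.summary()                    # Producer nodes skipped by default
    stats.summary(inc_producers=True)  # also list Producer nodes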
---
 include/aidge/data/Data.hpp                   |   3 +
 include/aidge/graph/StaticAnalysis.hpp        |  72 ++++++++++
 .../graph/pybind_StaticAnalysis.cpp           | 140 ++++++++++++++++++
 python_binding/pybind_core.cpp                |   2 +
 src/data/Data.cpp                             |  33 +++++
 src/graph/StaticAnalysis.cpp                  | 122 ++++++++++++++++
 6 files changed, 372 insertions(+)
 create mode 100644 include/aidge/graph/StaticAnalysis.hpp
 create mode 100644 python_binding/graph/pybind_StaticAnalysis.cpp
 create mode 100644 src/graph/StaticAnalysis.cpp

diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp
index 23221e653..df52b30f8 100644
--- a/include/aidge/data/Data.hpp
+++ b/include/aidge/data/Data.hpp
@@ -52,6 +52,9 @@ enum class DataType {
     Any
 };
 
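+/// Size in bits of a DataType (0 if not defined, e.g. DataType::Any).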
+size_t getDataTypeBitWidth(const DataType& type);
+
 enum class DataFormat {
     Default,
     NCHW,
diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp
new file mode 100644
index 000000000..3d63d2575
--- /dev/null
+++ b/include/aidge/graph/StaticAnalysis.hpp
@@ -0,0 +1,72 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CORE_GRAPH_STATICANALYSIS_H_
+#define AIDGE_CORE_GRAPH_STATICANALYSIS_H_
+
+#include <memory>
+
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+
+#include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Conv.hpp"
+
+namespace Aidge {
+class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
+public:
+    StaticAnalysis(std::shared_ptr<GraphView> graph);
+    virtual void summary(bool incProducers = false) const;
+    virtual ~StaticAnalysis() = default;
+
+protected:
+    std::shared_ptr<GraphView> mGraph;
+};
+
+class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
+public:
+    OperatorStats(const Operator& op);
+    size_t getNbParams() const;
+    virtual size_t getNbFixedParams() const { return 0; }
+    virtual size_t getNbTrainableParams() const;
+    virtual size_t getParamsSize() const;
+    virtual size_t getNbMemAccess() const { return 0; }
+    virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); }
+    virtual size_t getNbLogicOps() const { return 0; }
+    virtual size_t getNbCompOps() const { return 0; }
+    virtual size_t getNbMACOps() const { return 0; }
+    virtual size_t getNbFlops() const { return 0; }
+    virtual ~OperatorStats() = default;
+
+protected:
+    const Operator &mOp;
+};
+
+class ConvStats : public OperatorStats {
+public:
+    ConvStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<ConvStats> create(const Operator& op) {
+        return std::make_unique<ConvStats>(op);
+    }
+
+    size_t getNbMACOps() const override {
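+        // Placeholder in this initial version: the actual Conv MAC count is not computed yet.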
+        return 0;
+    }
+};
+
+REGISTRAR(OperatorStats, Conv_Op<2>::Type, ConvStats::create);
+} // namespace Aidge
+
+#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */
diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp
new file mode 100644
index 000000000..f9720d461
--- /dev/null
+++ b/python_binding/graph/pybind_StaticAnalysis.cpp
@@ -0,0 +1,140 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "aidge/graph/StaticAnalysis.hpp"
+
+namespace py = pybind11;
+namespace Aidge {
+
+/**
+ * @brief Trampoline class for binding OperatorStats, so that its virtual
+ * statistics getters can be overridden from Python.
+ */
+class pyOperatorStats: public OperatorStats {
+public:
+    using OperatorStats::OperatorStats; // Inherit constructors
+
+    size_t getNbFixedParams() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbFixedParams
+        );
+    }
+
+    size_t getNbTrainableParams() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbTrainableParams
+        );
+    }
+
+    size_t getParamsSize() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getParamsSize
+        );
+    }
+
+    size_t getNbMemAccess() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbMemAccess
+        );
+    }
+
+    size_t getNbArithmOps() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbArithmOps
+        );
+    }
+
+    size_t getNbLogicOps() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbLogicOps
+        );
+    }
+
+    size_t getNbCompOps() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbCompOps
+        );
+    }
+
+    size_t getNbMACOps() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbMACOps
+        );
+    }
+
+    size_t getNbFlops() const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            OperatorStats,
+            getNbFlops
+        );
+    }
+};
+
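+/**
+ * @brief Trampoline class for binding StaticAnalysis, so that its virtual
+ * methods (here summary()) can be overridden from Python.
+ */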
+class pyStaticAnalysis: public StaticAnalysis {
+public:
+    using StaticAnalysis::StaticAnalysis; // Inherit constructors
+
+    void summary(bool incProducers) const override {
+        PYBIND11_OVERRIDE(
+            void,
+            StaticAnalysis,
+            summary,
+            incProducers
+        );
+    }
+};
+
+void init_StaticAnalysis(py::module& m) {
+    py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::dynamic_attr())
+    .def(py::init<const Operator&>(), py::arg("op"))
+    .def("get_nb_params", &OperatorStats::getNbParams)
+    .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
+    .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
+    .def("get_params_size", &OperatorStats::getParamsSize)
+    .def("get_nb_mem_access", &OperatorStats::getNbMemAccess)
+    .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
+    .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
+    .def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
+    .def("get_nb_mac_ops", &OperatorStats::getNbMACOps)
+    .def("get_nb_flops", &OperatorStats::getNbFlops)
+    ;
+    declare_registrable<OperatorStats>(m, "OperatorStats");
+
+    py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::dynamic_attr())
+    .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
+    .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
+    ;
+}
+} // namespace Aidge
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index 02f4b732c..c287314f2 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -27,6 +27,7 @@ void init_OperatorImpl(py::module&);
 void init_Log(py::module&);
 void init_Operator(py::module&);
 void init_OperatorTensor(py::module&);
+void init_StaticAnalysis(py::module&);
 
 void init_Add(py::module&);
 void init_And(py::module&);
@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) {
     init_Log(m);
     init_Operator(m);
     init_OperatorTensor(m);
+    init_StaticAnalysis(m);
 
     init_Add(m);
     init_And(m);
diff --git a/src/data/Data.cpp b/src/data/Data.cpp
index 62a883d08..91c572897 100644
--- a/src/data/Data.cpp
+++ b/src/data/Data.cpp
@@ -11,6 +11,39 @@
 
 #include "aidge/data/Data.hpp"
 
+size_t Aidge::getDataTypeBitWidth(const DataType& type) {
+    switch (type) {
+    case DataType::Float64:   return 64;
+    case DataType::Float32:   return 32;
+    case DataType::Float16:   return 16;
+    case DataType::BFloat16:  return 16;
+    case DataType::Binary:    return 1;
+    case DataType::Ternary:   return 2;
+    case DataType::Int2:      return 2;
+    case DataType::Int3:      return 3;
+    case DataType::Int4:      return 4;
+    case DataType::Int5:      return 5;
+    case DataType::Int6:      return 6;
+    case DataType::Int7:      return 7;
+    case DataType::Int8:      return 8;
+    case DataType::Int16:     return 16;
+    case DataType::Int32:     return 32;
+    case DataType::Int64:     return 64;
+    case DataType::UInt2:     return 2;
+    case DataType::UInt3:     return 3;
+    case DataType::UInt4:     return 4;
+    case DataType::UInt5:     return 5;
+    case DataType::UInt6:     return 6;
+    case DataType::UInt7:     return 7;
+    case DataType::UInt8:     return 8;
+    case DataType::UInt16:    return 16;
+    case DataType::UInt32:    return 32;
+    case DataType::UInt64:    return 64;
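+    // Types with no fixed bit width (e.g. DataType::Any) report 0.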
+    default:                  return 0;
+    }
+}
+
 Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) {
     // Permutation array from default format to src format
     const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)];
diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp
new file mode 100644
index 000000000..ec08ae5da
--- /dev/null
+++ b/src/graph/StaticAnalysis.cpp
@@ -0,0 +1,122 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/graph/StaticAnalysis.hpp"
+
+Aidge::OperatorStats::OperatorStats(const Operator& op)
+  : mOp(op)
+{
+    //ctor
+}
+
+size_t Aidge::OperatorStats::getNbParams() const {
+    return (getNbFixedParams() + getNbTrainableParams());
+}
+
+size_t Aidge::OperatorStats::getNbTrainableParams() const {
+    size_t nbParams = 0;
+    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
+    if (opTensor) {
+        for (size_t i = 0; i < mOp.nbInputs(); ++i) {
+            if ((mOp.inputCategory(i) == InputCategory::Param
+                    || mOp.inputCategory(i) == InputCategory::OptionalParam)
+                && opTensor->getInput(i))
+            {
+                nbParams += opTensor->getInput(i)->size();
+            }
+        }
+    }
+    return nbParams;
+}
+
+size_t Aidge::OperatorStats::getParamsSize() const {
+    size_t paramsSize = 0;
+    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
+    if (opTensor) {
+        for (size_t i = 0; i < mOp.nbInputs(); ++i) {
+            if ((mOp.inputCategory(i) == InputCategory::Param
+                    || mOp.inputCategory(i) == InputCategory::OptionalParam)
+                && opTensor->getInput(i))
+            {
+                paramsSize += opTensor->getInput(i)->size() * getDataTypeBitWidth(opTensor->getInput(i)->dataType());
+            }
+        }
+    }
+    return paramsSize;
+}
+
+Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
+  : mGraph(graph)
+{
+    //ctor
+}
+
+void Aidge::StaticAnalysis::summary(bool incProducers) const {
+    fmt::println("--------------------------------------------------------------------------------");
+    fmt::println("                        Layer (type)               Output Shape         Param #");
+    fmt::println("================================================================================");
+
+    size_t nbTrainableParams = 0;
+    size_t nbFixedParams = 0;
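+    // Memory sizes are accumulated in bits and converted to MB when printed.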
+    size_t paramsSize = 0;
+    size_t fwdBwdSize = 0;
+
+    const auto namePtrTable = mGraph->getRankedNodesName("{0} ({1}#{3})");
+    for (const auto& node : mGraph->getOrderedNodes()) {
+        if (node->type() == Producer_Op::Type && !incProducers) {
+            continue;
+        }
+
+        auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
+        std::string outputDimsStr = fmt::format("{: >27}", "?");
+        if (opTensor) {
+            const auto outputDims = opTensor->getOutput(0)->dims();
+            outputDimsStr = fmt::format("{: >27}", fmt::format("{}", outputDims));
+            for (size_t out = 0; out < node->nbOutputs(); ++out) {
+                const auto output = opTensor->getOutput(out);
+                if (output && node->type() != Producer_Op::Type) {
+                    fwdBwdSize += output->size() * getDataTypeBitWidth(output->dataType());
+                }
+            }
+        }
+
+        const auto stats = (Registrar<OperatorStats>::exists(node->type()))
+            ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
+            : std::make_shared<OperatorStats>(*(node->getOperator()));
+        nbTrainableParams += stats->getNbTrainableParams();
+        nbFixedParams += stats->getNbFixedParams();
+        paramsSize += stats->getParamsSize();
+        fmt::println("{: >36}{}{: >16}",
+          namePtrTable.at(node), outputDimsStr, stats->getNbParams());
+    }
+
+    size_t inputSize = 0;
+    for (const auto& input : mGraph->getOrderedInputs()) {
+        if (input.first) {
+            auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator());
+            if (opTensor && opTensor->getInput(input.second)) {
+                inputSize += opTensor->getInput(input.second)->size() * getDataTypeBitWidth(opTensor->getInput(input.second)->dataType());
+            }
+        }
+    }
+
+    fmt::println("================================================================================");
+    fmt::println("Total params: {}", nbTrainableParams + nbFixedParams);
+    fmt::println("Trainable params: {}", nbTrainableParams);
+    fmt::println("Non-trainable params: {}", nbFixedParams);
+    fmt::println("--------------------------------------------------------------------------------");
+    fmt::println("Input size (MB): {}", inputSize / 8 / 1024 / 1024);
+    fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8 / 1024 / 1024);
+    fmt::println("Params size (MB): {}", paramsSize / 8 / 1024 / 1024);
+    fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8 / 1024 / 1024);
+    fmt::println("--------------------------------------------------------------------------------");
+}
-- 
GitLab