From 4daa23a0cb2fac4060692461e2bd5391f64d486b Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 9 Oct 2024 14:58:36 +0200 Subject: [PATCH] Initial version of hybrid C++/Python static analysis --- include/aidge/data/Data.hpp | 2 + include/aidge/graph/StaticAnalysis.hpp | 72 ++++++++++ .../graph/pybind_StaticAnalysis.cpp | 136 ++++++++++++++++++ python_binding/pybind_core.cpp | 2 + src/data/Data.cpp | 33 +++++ src/graph/StaticAnalysis.cpp | 122 ++++++++++++++++ 6 files changed, 367 insertions(+) create mode 100644 include/aidge/graph/StaticAnalysis.hpp create mode 100644 python_binding/graph/pybind_StaticAnalysis.cpp create mode 100644 src/graph/StaticAnalysis.cpp diff --git a/include/aidge/data/Data.hpp b/include/aidge/data/Data.hpp index 23221e653..df52b30f8 100644 --- a/include/aidge/data/Data.hpp +++ b/include/aidge/data/Data.hpp @@ -52,6 +52,8 @@ enum class DataType { Any }; +size_t getDataTypeBitWidth(const DataType& type); + enum class DataFormat { Default, NCHW, diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp new file mode 100644 index 000000000..3d63d2575 --- /dev/null +++ b/include/aidge/graph/StaticAnalysis.hpp @@ -0,0 +1,72 @@ + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_GRAPH_STATICANALYSIS_H_ +#define AIDGE_CORE_GRAPH_STATICANALYSIS_H_ + +#include <memory> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/Conv.hpp" + +namespace Aidge { +class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { +public: + StaticAnalysis(std::shared_ptr<GraphView> graph); + virtual void summary(bool incProducers = false) const; + virtual ~StaticAnalysis() = default; + +protected: + std::shared_ptr<GraphView> mGraph; +}; + +class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> { +public: + OperatorStats(const Operator& op); + size_t getNbParams() const; + virtual size_t getNbFixedParams() const { return 0; }; + virtual size_t getNbTrainableParams() const; + virtual size_t getParamsSize() const; + virtual size_t getNbMemAccess() const { return 0; }; + virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); }; + virtual size_t getNbLogicOps() const { return 0; }; + virtual size_t getNbCompOps() const { return 0; }; + virtual size_t getNbMACOps() const { return 0; }; + virtual size_t getNbFlops() const { return 0; }; + virtual ~OperatorStats() = default; + +protected: + const Operator &mOp; +}; + +class ConvStats : public OperatorStats { +public: + ConvStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ConvStats> create(const Operator& op) { + return std::make_unique<ConvStats>(op); + } + + size_t getNbMACOps() const { + return 0; + } +}; + +REGISTRAR(OperatorStats, Conv_Op<2>::Type, ConvStats::create); +} + +#endif /* AIDGE_CORE_GRAPH_STATICANALYSIS_H_ */ diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp new file mode 100644 index 000000000..f9720d461 --- /dev/null +++ b/python_binding/graph/pybind_StaticAnalysis.cpp @@ -0,0 +1,136 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include "aidge/graph/StaticAnalysis.hpp" + +namespace py = pybind11; +namespace Aidge { + +/** + * @brief Trampoline class for binding + * + */ +class pyOperatorStats: public OperatorStats { +public: + using OperatorStats::OperatorStats; // Inherit constructors + + size_t getNbFixedParams() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbFixedParams + ); + } + + size_t getNbTrainableParams() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbTrainableParams + ); + } + + size_t getParamsSize() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getParamsSize + ); + } + + size_t getNbMemAccess() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbMemAccess + ); + } + + size_t getNbArithmOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbArithmOps + ); + } + + size_t getNbLogicOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbLogicOps + ); + } + + size_t getNbCompOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbCompOps + ); + } + + size_t getNbMACOps() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbMACOps + ); + } + + size_t getNbFlops() const override { + PYBIND11_OVERRIDE( + size_t, + OperatorStats, + getNbFlops + ); + } +}; + +class pyStaticAnalysis: public StaticAnalysis { +public: + using StaticAnalysis::StaticAnalysis; // Inherit constructors + + void summary(bool incProducers) const override { + PYBIND11_OVERRIDE( + void, + StaticAnalysis, + summary, + incProducers + ); + } +}; + +void init_StaticAnalysis(py::module& m){ + py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::dynamic_attr()) + .def(py::init<const Operator&>(), py::arg("op")) + .def("get_nb_params", &OperatorStats::getNbParams) + .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams) + .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams) + .def("get_params_size", &OperatorStats::getParamsSize) + .def("get_nb_mem_access", &OperatorStats::getNbMemAccess) + .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps) + .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps) + .def("get_nb_comp_ops", &OperatorStats::getNbCompOps) + .def("get_nb_mac_ops", &OperatorStats::getNbMACOps) + .def("get_nb_flops", &OperatorStats::getNbFlops) + ; + declare_registrable<OperatorStats>(m, "OperatorStats"); + + py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::dynamic_attr()) + .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) + .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) + ; +} +} diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 02f4b732c..c287314f2 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -27,6 +27,7 @@ void init_OperatorImpl(py::module&); void init_Log(py::module&); void init_Operator(py::module&); void init_OperatorTensor(py::module&); +void init_StaticAnalysis(py::module&); void init_Add(py::module&); void init_And(py::module&); @@ -117,6 +118,7 @@ void init_Aidge(py::module& m) { init_Log(m); init_Operator(m); init_OperatorTensor(m); + init_StaticAnalysis(m); init_Add(m); init_And(m); diff --git a/src/data/Data.cpp b/src/data/Data.cpp index 62a883d08..91c572897 100644 --- a/src/data/Data.cpp +++ b/src/data/Data.cpp @@ -11,6 +11,39 @@ #include "aidge/data/Data.hpp" +size_t Aidge::getDataTypeBitWidth(const DataType& type) { + switch (type) { + case DataType::Float64: return 64; + case DataType::Float32: return 32; + case DataType::Float16: return 16; + case DataType::BFloat16: return 16; + case DataType::Binary: return 1; + case DataType::Ternary: return 2; + case DataType::Int2: return 2; + case DataType::Int3: return 3; + case DataType::Int4: return 4; + case DataType::Int5: return 5; + case DataType::Int6: return 6; + case DataType::Int7: return 7; + case DataType::Int8: return 8; + case DataType::Int16: return 16; + case DataType::Int32: return 32; + case DataType::Int64: return 64; + case DataType::UInt2: return 2; + case DataType::UInt3: return 3; + case DataType::UInt4: return 4; + case DataType::UInt5: return 5; + case DataType::UInt6: return 6; + case DataType::UInt7: return 7; + case DataType::UInt8: return 8; + case DataType::UInt16: return 16; + case DataType::UInt32: return 32; + case DataType::UInt64: return 64; + default: return 0; + } + return 0; +} + Aidge::DataFormatTranspose Aidge::getDataFormatTranspose(const DataFormat& src, const DataFormat& dst) { // Permutation array from default format to src format const auto srcDefToFormat = DataFormatTransposeDict[static_cast<int>(src)]; diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp new file mode 100644 index 000000000..ec08ae5da --- /dev/null +++ b/src/graph/StaticAnalysis.cpp @@ -0,0 +1,122 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/graph/StaticAnalysis.hpp" + +Aidge::OperatorStats::OperatorStats(const Operator& op) + : mOp(op) +{ + //ctor +} + +size_t Aidge::OperatorStats::getNbParams() const { + return (getNbFixedParams() + getNbTrainableParams()); +} + +size_t Aidge::OperatorStats::getNbTrainableParams() const { + size_t nbParams = 0; + const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); + if (opTensor) { + for (size_t i = 0; i < mOp.nbInputs(); ++i) { + if ((mOp.inputCategory(i) == InputCategory::Param + || mOp.inputCategory(i) == InputCategory::OptionalParam) + && opTensor->getInput(i)) + { + nbParams += opTensor->getInput(i)->size(); + } + } + } + return nbParams; +} + +size_t Aidge::OperatorStats::getParamsSize() const { + size_t paramsSize = 0; + const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); + if (opTensor) { + for (size_t i = 0; i < mOp.nbInputs(); ++i) { + if ((mOp.inputCategory(i) == InputCategory::Param + || mOp.inputCategory(i) == InputCategory::OptionalParam) + && opTensor->getInput(i)) + { + paramsSize += opTensor->getInput(i)->size() * getDataTypeBitWidth(opTensor->getInput(i)->dataType()); + } + } + } + return paramsSize; +} + +Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph) + : mGraph(graph) +{ + //ctor +} + +void Aidge::StaticAnalysis::summary(bool incProducers) const { + fmt::println("--------------------------------------------------------------------------------"); + fmt::println(" Layer (type) Output Shape Param #"); + fmt::println("================================================================================"); + + size_t nbTrainableParams = 0; + size_t nbFixedParams = 0; + size_t paramsSize = 0; + size_t fwdBwdSize = 0; + + const auto namePtrTable = mGraph->getRankedNodesName("{0} ({1}#{3})"); + for (const auto node : mGraph->getOrderedNodes()) { + if (node->type() == Producer_Op::Type && !incProducers) { + continue; + } + + auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); + std::string outputDimsStr = fmt::format("{: >27}", "?"); + if (opTensor) { + const auto outputDims = opTensor->getOutput(0)->dims(); + outputDimsStr = fmt::format("{: >27}", fmt::format("{}", outputDims)); + + for (size_t out = 0; out < node->nbOutputs(); ++out) { + const auto output = opTensor->getOutput(out); + if (output && node->type() != Producer_Op::Type) { + fwdBwdSize += output->size(); + } + } + } + + const auto stats = (Registrar<OperatorStats>::exists(node->type())) + ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator())) + : std::make_shared<OperatorStats>(*(node->getOperator())); + nbTrainableParams += stats->getNbTrainableParams(); + nbFixedParams += stats->getNbFixedParams(); + paramsSize += stats->getParamsSize(); + fmt::println("{: >36}{}{: >16}", + namePtrTable.at(node), outputDimsStr, stats->getNbParams()); + } + + size_t inputSize = 0; + for (const auto input : mGraph->getOrderedInputs()) { + if (input.first) { + auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator()); + if (opTensor && opTensor->getInput(input.second)) { + inputSize += opTensor->getInput(input.second)->size(); + } + } + } + + fmt::println("================================================================================"); + fmt::println("Total params: {}", nbTrainableParams + nbFixedParams); + fmt::println("Trainable params: {}", nbTrainableParams); + fmt::println("Non-trainable params: {}", nbFixedParams); + fmt::println("--------------------------------------------------------------------------------"); + fmt::println("Input size (MB): {}", inputSize / 8 / 1024 / 1024); + fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8 / 1024 / 1024); + fmt::println("Params size (MB): {}", paramsSize / 8 / 1024 / 1024); + fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8 / 1024 / 1024); + fmt::println("--------------------------------------------------------------------------------"); +} -- GitLab