diff --git a/aidge_core/static_analysis.py b/aidge_core/static_analysis.py index 91c1c30a22a6ad994934da588b09b70e8868142b..add0c991004304770056bde13aadc531eb6ce4cc 100644 --- a/aidge_core/static_analysis.py +++ b/aidge_core/static_analysis.py @@ -6,16 +6,36 @@ import aidge_core class StaticAnalysisExt(aidge_core.StaticAnalysis): def log_nb_params(self, filename, title=None, log_scale=False): - self._log_callback(aidge_core.OperatorStats.get_nb_params, filename, title, log_scale) + namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); + nodes = self.get_graph().get_ordered_nodes() + series = [] + legend = None - def log_nb_fixed_params(self, filename, title=None, log_scale=False): - self._log_callback(aidge_core.OperatorStats.get_nb_fixed_params, filename, title, log_scale) + for node in nodes: + if node.type() == "Producer": + continue - def log_nb_trainable_params(self, filename, title=None, log_scale=False): - self._log_callback(aidge_core.OperatorStats.get_nb_trainable_params, filename, title, log_scale) + name = namePtrTable[node] + series.append([name, self.get_nb_params(node)]) + + if title is None: title = "log_nb_params" + self._log_bar(series, filename, title, legend, log_scale) def log_params_size(self, filename, title=None, log_scale=False): - self._log_callback(aidge_core.OperatorStats.get_params_size, filename, title, log_scale) + namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})"); + nodes = self.get_graph().get_ordered_nodes() + series = [] + legend = None + + for node in nodes: + if node.type() == "Producer": + continue + + name = namePtrTable[node] + series.append([name, self.log_params_size(node)]) + + if title is None: title = "log_params_size" + self._log_bar(series, filename, title, legend, log_scale) def log_nb_arithm_ops(self, filename, title=None, log_scale=False): self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale) diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp index 3b6747dfd9e721c5a3f477ae4cd10935a51602c0..16592041ce04e16930a757578e4f42db453ed645 100644 --- a/include/aidge/graph/StaticAnalysis.hpp +++ b/include/aidge/graph/StaticAnalysis.hpp @@ -33,17 +33,13 @@ namespace Aidge { /** - * @brief Base class to compute statistics from an Operator + * @brief Base class to compute statistics from an Operator. * */ class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> { public: OperatorStats(const Operator& op); const Operator& getOperator() const noexcept { return mOp; } - size_t getNbParams() const; - virtual size_t getNbFixedParams() const { return 0; }; - virtual size_t getNbTrainableParams() const; - virtual size_t getParamsSize() const; /** * @brief Get the total number of arithmetic operations for the operator. @@ -126,10 +122,35 @@ class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> { public: StaticAnalysis(std::shared_ptr<GraphView> graph); const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } - size_t getNbParams() const { return accumulate(&OperatorStats::getNbParams); } - size_t getNbFixedParams() const { return accumulate(&OperatorStats::getNbFixedParams); } - size_t getNbTrainableParams() const { return accumulate(&OperatorStats::getNbTrainableParams); } - size_t getParamsSize() const { return accumulate(&OperatorStats::getParamsSize); } + + /** + * @brief Get the number of parameters associated to a node. This includes + * all Producers directly connected to the node's inputs as well as all + * internal Producers (in case of a meta operator). + * + * Note: this function does not check if parameters are shared between + * several nodes or not. This means that simply adding parameters count from + * several nodes may lead to a higher number of parameters than in reality. + * + * @param node Node + * @return size_t Number of parameters + */ + virtual size_t getNbParams(std::shared_ptr<Node> node) const; + + /** + * @brief Get the total parameters memory size, in bits, associated to a node. + * This includes all Producers directly connected to the node's inputs as + * well as all internal Producers (in case of a meta operator). + * + * Note: this function does not check if parameters are shared between + * several nodes or not. This means that simply adding parameters size from + * several nodes may lead to a higher parameter size than in reality. + * + * @param node Node + * @return size_t Total parameters memory, in bits + */ + virtual size_t getParamsSize(std::shared_ptr<Node> node) const; + size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); } size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); } size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); } @@ -158,9 +179,6 @@ public: return std::make_unique<MetaOpStats>(op); } - size_t getNbFixedParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFixedParams(); } - size_t getNbTrainableParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbTrainableParams(); } - size_t getParamsSize() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getParamsSize(); } size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); } size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); } size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); } diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp index e1c3744df142472e5e44dae9d1d540724c6a9c20..04a19f6cce9a0d697d4a735161fa887c395f6130 100644 --- a/python_binding/graph/pybind_StaticAnalysis.cpp +++ b/python_binding/graph/pybind_StaticAnalysis.cpp @@ -25,30 +25,6 @@ class pyOperatorStats: public OperatorStats { public: using OperatorStats::OperatorStats; // Inherit constructors - size_t getNbFixedParams() const override { - PYBIND11_OVERRIDE( - size_t, - OperatorStats, - getNbFixedParams - ); - } - - size_t getNbTrainableParams() const override { - PYBIND11_OVERRIDE( - size_t, - OperatorStats, - getNbTrainableParams - ); - } - - size_t getParamsSize() const override { - PYBIND11_OVERRIDE( - size_t, - OperatorStats, - getParamsSize - ); - } - size_t getNbArithmOps() const override { PYBIND11_OVERRIDE( size_t, @@ -102,6 +78,24 @@ class pyStaticAnalysis: public StaticAnalysis { public: using StaticAnalysis::StaticAnalysis; // Inherit constructors + size_t getNbParams(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getNbParams, + node + ); + } + + size_t getParamsSize(std::shared_ptr<Node> node) const override { + PYBIND11_OVERRIDE( + size_t, + StaticAnalysis, + getParamsSize, + node + ); + } + void summary(bool incProducers) const override { PYBIND11_OVERRIDE( void, @@ -122,10 +116,6 @@ void init_StaticAnalysis(py::module& m){ py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr()) .def(py::init<const Operator&>(), py::arg("op")) .def("get_operator", &OperatorStats::getOperator) - .def("get_nb_params", &OperatorStats::getNbParams) - .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams) - .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams) - .def("get_params_size", &OperatorStats::getParamsSize) .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps) .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps) .def("get_nb_comp_ops", &OperatorStats::getNbCompOps) @@ -140,10 +130,8 @@ void init_StaticAnalysis(py::module& m){ py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr()) .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) .def("get_graph", &StaticAnalysis::getGraph) - .def("get_nb_params", &StaticAnalysis::getNbParams) - .def("get_nb_fixed_params", &StaticAnalysis::getNbFixedParams) - .def("get_nb_trainable_params", &StaticAnalysis::getNbTrainableParams) - .def("get_params_size", &StaticAnalysis::getParamsSize) + .def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node")) + .def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node")) .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps) .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps) .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps) diff --git a/src/graph/StaticAnalysis.cpp b/src/graph/StaticAnalysis.cpp index ddec32505739e897e6787d50987d9f4b3dec323d..453b107475c0cc6a5e127c03861d8cc4361f8734 100644 --- a/src/graph/StaticAnalysis.cpp +++ b/src/graph/StaticAnalysis.cpp @@ -17,42 +17,6 @@ Aidge::OperatorStats::OperatorStats(const Operator& op) //ctor } -size_t Aidge::OperatorStats::getNbParams() const { - return (getNbFixedParams() + getNbTrainableParams()); -} - -size_t Aidge::OperatorStats::getNbTrainableParams() const { - size_t nbParams = 0; - const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); - if (opTensor) { - for (size_t i = 0; i < mOp.nbInputs(); ++i) { - if ((mOp.inputCategory(i) == InputCategory::Param - || mOp.inputCategory(i) == InputCategory::OptionalParam) - && opTensor->getInput(i)) - { - nbParams += opTensor->getInput(i)->size(); - } - } - } - return nbParams; -} - -size_t Aidge::OperatorStats::getParamsSize() const { - size_t paramsSize = 0; - const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); - if (opTensor) { - for (size_t i = 0; i < mOp.nbInputs(); ++i) { - if ((mOp.inputCategory(i) == InputCategory::Param - || mOp.inputCategory(i) == InputCategory::OptionalParam) - && opTensor->getInput(i)) - { - paramsSize += opTensor->getInput(i)->size() * getDataTypeBitWidth(opTensor->getInput(i)->dataType()); - } - } - } - return paramsSize; -} - size_t Aidge::OperatorStats::getNbArithmIntOps() const { const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp); if (opTensor) { @@ -74,8 +38,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { fmt::println(" Layer (type) Output Shape Param #"); fmt::println("================================================================================"); - size_t nbTrainableParams = 0; - size_t nbFixedParams = 0; + size_t nbParams = 0; size_t paramsSize = 0; size_t fwdBwdSize = 0; @@ -99,12 +62,10 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { } } - const auto stats = getOpStats(node); - nbTrainableParams += stats->getNbTrainableParams(); - nbFixedParams += stats->getNbFixedParams(); - paramsSize += stats->getParamsSize(); + nbParams += getNbParams(node); + paramsSize += getParamsSize(node); fmt::println("{: >36}{}{: >16}", - namePtrTable.at(node), outputDimsStr, stats->getNbParams()); + namePtrTable.at(node), outputDimsStr, getNbParams(node)); } size_t inputSize = 0; @@ -118,9 +79,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { } fmt::println("================================================================================"); - fmt::println("Total params: {}", nbTrainableParams + nbFixedParams); - fmt::println("Trainable params: {}", nbTrainableParams); - fmt::println("Non-trainable params: {}", nbFixedParams); + fmt::println("Total params: {}", nbParams); fmt::println("--------------------------------------------------------------------------------"); fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024); fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024); @@ -129,6 +88,68 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const { fmt::println("--------------------------------------------------------------------------------"); } +size_t Aidge::StaticAnalysis::getNbParams(std::shared_ptr<Node> node) const { + const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); + + size_t nbParams = 0; + + // Look for Producers directly attached to the node's inputs. + size_t i = 0; + for (auto parent : node->inputs()) { + if (parent.first && mGraph->inView(parent.first)) { + if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) { + nbParams += opTensor->getInput(i)->size(); + } + } + ++i; + } + + // Look for internal Producers, in case of meta-op. + if (!node->getOperator()->isAtomic()) { + const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph(); + for (const auto internalNode : microGraph->getNodes()) { + if (internalNode->type() == Producer_Op::Type) { + const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator()); + nbParams += internalOpTensor->getOutput(0)->size(); + } + } + } + + return nbParams; +} + +size_t Aidge::StaticAnalysis::getParamsSize(std::shared_ptr<Node> node) const { + const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); + + size_t paramsSize = 0; + + // Look for Producers directly attached to the node's inputs. + size_t i = 0; + for (auto parent : node->inputs()) { + if (parent.first && mGraph->inView(parent.first)) { + if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) { + paramsSize += opTensor->getInput(i)->size() + * getDataTypeBitWidth(opTensor->getInput(i)->dataType()); + } + } + ++i; + } + + // Look for internal Producers, in case of meta-op. + if (!node->getOperator()->isAtomic()) { + const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph(); + for (const auto internalNode : microGraph->getNodes()) { + if (internalNode->type() == Producer_Op::Type) { + const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator()); + paramsSize += internalOpTensor->getOutput(0)->size() + * getDataTypeBitWidth(internalOpTensor->getOutput(0)->dataType()); + } + } + } + + return paramsSize; +} + std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const { return (Registrar<OperatorStats>::exists(node->type())) ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))