Skip to content
Snippets Groups Projects
Commit 9ebe0eaa authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed parameter count

parent 034e5f83
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!219Initial version of hybrid C++/Python static analysis
Pipeline #58966 failed
......@@ -6,16 +6,36 @@ import aidge_core
class StaticAnalysisExt(aidge_core.StaticAnalysis):
def log_nb_params(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_params, filename, title, log_scale)
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
series = []
legend = None
def log_nb_fixed_params(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_fixed_params, filename, title, log_scale)
for node in nodes:
if node.type() == "Producer":
continue
def log_nb_trainable_params(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_trainable_params, filename, title, log_scale)
name = namePtrTable[node]
series.append([name, self.get_nb_params(node)])
if title is None: title = "log_nb_params"
self._log_bar(series, filename, title, legend, log_scale)
def log_params_size(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_params_size, filename, title, log_scale)
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
series = []
legend = None
for node in nodes:
if node.type() == "Producer":
continue
name = namePtrTable[node]
series.append([name, self.log_params_size(node)])
if title is None: title = "log_params_size"
self._log_bar(series, filename, title, legend, log_scale)
def log_nb_arithm_ops(self, filename, title=None, log_scale=False):
self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale)
......
......@@ -33,17 +33,13 @@
namespace Aidge {
/**
* @brief Base class to compute statistics from an Operator
* @brief Base class to compute statistics from an Operator.
*
*/
class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
public:
OperatorStats(const Operator& op);
const Operator& getOperator() const noexcept { return mOp; }
size_t getNbParams() const;
virtual size_t getNbFixedParams() const { return 0; };
virtual size_t getNbTrainableParams() const;
virtual size_t getParamsSize() const;
/**
* @brief Get the total number of arithmetic operations for the operator.
......@@ -126,10 +122,35 @@ class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
public:
StaticAnalysis(std::shared_ptr<GraphView> graph);
const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
size_t getNbParams() const { return accumulate(&OperatorStats::getNbParams); }
size_t getNbFixedParams() const { return accumulate(&OperatorStats::getNbFixedParams); }
size_t getNbTrainableParams() const { return accumulate(&OperatorStats::getNbTrainableParams); }
size_t getParamsSize() const { return accumulate(&OperatorStats::getParamsSize); }
/**
* @brief Get the number of parameters associated to a node. This includes
* all Producers directly connected to the node's inputs as well as all
* internal Producers (in case of a meta operator).
*
* Note: this function does not check if parameters are shared between
* several nodes or not. This means that simply adding parameters count from
* several nodes may lead to a higher number of parameters than in reality.
*
* @param node Node
* @return size_t Number of parameters
*/
virtual size_t getNbParams(std::shared_ptr<Node> node) const;
/**
* @brief Get the total parameters memory size, in bits, associated to a node.
* This includes all Producers directly connected to the node's inputs as
* well as all internal Producers (in case of a meta operator).
*
* Note: this function does not check if parameters are shared between
* several nodes or not. This means that simply adding parameters size from
* several nodes may lead to a higher parameter size than in reality.
*
* @param node Node
* @return size_t Total parameters memory, in bits
*/
virtual size_t getParamsSize(std::shared_ptr<Node> node) const;
size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
......@@ -158,9 +179,6 @@ public:
return std::make_unique<MetaOpStats>(op);
}
size_t getNbFixedParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFixedParams(); }
size_t getNbTrainableParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbTrainableParams(); }
size_t getParamsSize() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getParamsSize(); }
size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); }
size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); }
size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); }
......
......@@ -25,30 +25,6 @@ class pyOperatorStats: public OperatorStats {
public:
using OperatorStats::OperatorStats; // Inherit constructors
size_t getNbFixedParams() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbFixedParams
);
}
size_t getNbTrainableParams() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getNbTrainableParams
);
}
size_t getParamsSize() const override {
PYBIND11_OVERRIDE(
size_t,
OperatorStats,
getParamsSize
);
}
size_t getNbArithmOps() const override {
PYBIND11_OVERRIDE(
size_t,
......@@ -102,6 +78,24 @@ class pyStaticAnalysis: public StaticAnalysis {
public:
using StaticAnalysis::StaticAnalysis; // Inherit constructors
size_t getNbParams(std::shared_ptr<Node> node) const override {
PYBIND11_OVERRIDE(
size_t,
StaticAnalysis,
getNbParams,
node
);
}
size_t getParamsSize(std::shared_ptr<Node> node) const override {
PYBIND11_OVERRIDE(
size_t,
StaticAnalysis,
getParamsSize,
node
);
}
void summary(bool incProducers) const override {
PYBIND11_OVERRIDE(
void,
......@@ -122,10 +116,6 @@ void init_StaticAnalysis(py::module& m){
py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr())
.def(py::init<const Operator&>(), py::arg("op"))
.def("get_operator", &OperatorStats::getOperator)
.def("get_nb_params", &OperatorStats::getNbParams)
.def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
.def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
.def("get_params_size", &OperatorStats::getParamsSize)
.def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
.def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
.def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
......@@ -140,10 +130,8 @@ void init_StaticAnalysis(py::module& m){
py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr())
.def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
.def("get_graph", &StaticAnalysis::getGraph)
.def("get_nb_params", &StaticAnalysis::getNbParams)
.def("get_nb_fixed_params", &StaticAnalysis::getNbFixedParams)
.def("get_nb_trainable_params", &StaticAnalysis::getNbTrainableParams)
.def("get_params_size", &StaticAnalysis::getParamsSize)
.def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node"))
.def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node"))
.def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
.def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
.def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
......
......@@ -17,42 +17,6 @@ Aidge::OperatorStats::OperatorStats(const Operator& op)
//ctor
}
size_t Aidge::OperatorStats::getNbParams() const {
return (getNbFixedParams() + getNbTrainableParams());
}
size_t Aidge::OperatorStats::getNbTrainableParams() const {
size_t nbParams = 0;
const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
if (opTensor) {
for (size_t i = 0; i < mOp.nbInputs(); ++i) {
if ((mOp.inputCategory(i) == InputCategory::Param
|| mOp.inputCategory(i) == InputCategory::OptionalParam)
&& opTensor->getInput(i))
{
nbParams += opTensor->getInput(i)->size();
}
}
}
return nbParams;
}
size_t Aidge::OperatorStats::getParamsSize() const {
size_t paramsSize = 0;
const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
if (opTensor) {
for (size_t i = 0; i < mOp.nbInputs(); ++i) {
if ((mOp.inputCategory(i) == InputCategory::Param
|| mOp.inputCategory(i) == InputCategory::OptionalParam)
&& opTensor->getInput(i))
{
paramsSize += opTensor->getInput(i)->size() * getDataTypeBitWidth(opTensor->getInput(i)->dataType());
}
}
}
return paramsSize;
}
size_t Aidge::OperatorStats::getNbArithmIntOps() const {
const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
if (opTensor) {
......@@ -74,8 +38,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
fmt::println(" Layer (type) Output Shape Param #");
fmt::println("================================================================================");
size_t nbTrainableParams = 0;
size_t nbFixedParams = 0;
size_t nbParams = 0;
size_t paramsSize = 0;
size_t fwdBwdSize = 0;
......@@ -99,12 +62,10 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
}
}
const auto stats = getOpStats(node);
nbTrainableParams += stats->getNbTrainableParams();
nbFixedParams += stats->getNbFixedParams();
paramsSize += stats->getParamsSize();
nbParams += getNbParams(node);
paramsSize += getParamsSize(node);
fmt::println("{: >36}{}{: >16}",
namePtrTable.at(node), outputDimsStr, stats->getNbParams());
namePtrTable.at(node), outputDimsStr, getNbParams(node));
}
size_t inputSize = 0;
......@@ -118,9 +79,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
}
fmt::println("================================================================================");
fmt::println("Total params: {}", nbTrainableParams + nbFixedParams);
fmt::println("Trainable params: {}", nbTrainableParams);
fmt::println("Non-trainable params: {}", nbFixedParams);
fmt::println("Total params: {}", nbParams);
fmt::println("--------------------------------------------------------------------------------");
fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024);
fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024);
......@@ -129,6 +88,68 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
fmt::println("--------------------------------------------------------------------------------");
}
size_t Aidge::StaticAnalysis::getNbParams(std::shared_ptr<Node> node) const {
const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
size_t nbParams = 0;
// Look for Producers directly attached to the node's inputs.
size_t i = 0;
for (auto parent : node->inputs()) {
if (parent.first && mGraph->inView(parent.first)) {
if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) {
nbParams += opTensor->getInput(i)->size();
}
}
++i;
}
// Look for internal Producers, in case of meta-op.
if (!node->getOperator()->isAtomic()) {
const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph();
for (const auto internalNode : microGraph->getNodes()) {
if (internalNode->type() == Producer_Op::Type) {
const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator());
nbParams += internalOpTensor->getOutput(0)->size();
}
}
}
return nbParams;
}
size_t Aidge::StaticAnalysis::getParamsSize(std::shared_ptr<Node> node) const {
const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
size_t paramsSize = 0;
// Look for Producers directly attached to the node's inputs.
size_t i = 0;
for (auto parent : node->inputs()) {
if (parent.first && mGraph->inView(parent.first)) {
if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) {
paramsSize += opTensor->getInput(i)->size()
* getDataTypeBitWidth(opTensor->getInput(i)->dataType());
}
}
++i;
}
// Look for internal Producers, in case of meta-op.
if (!node->getOperator()->isAtomic()) {
const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph();
for (const auto internalNode : microGraph->getNodes()) {
if (internalNode->type() == Producer_Op::Type) {
const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator());
paramsSize += internalOpTensor->getOutput(0)->size()
* getDataTypeBitWidth(internalOpTensor->getOutput(0)->dataType());
}
}
}
return paramsSize;
}
std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
return (Registrar<OperatorStats>::exists(node->type()))
? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment