Commit 035cade1 authored by Olivier BICHLER, committed by Axel Farrugia

Fixed parameter count

parent 9675715a
2 merge requests: !279 v0.4.0, !250 [Feat](Exports) Add custom options to exports
Python analysis helpers (file name not shown in this view):

```diff
@@ -6,16 +6,36 @@ import aidge_core
 class StaticAnalysisExt(aidge_core.StaticAnalysis):
     def log_nb_params(self, filename, title=None, log_scale=False):
-        self._log_callback(aidge_core.OperatorStats.get_nb_params, filename, title, log_scale)
-
-    def log_nb_fixed_params(self, filename, title=None, log_scale=False):
-        self._log_callback(aidge_core.OperatorStats.get_nb_fixed_params, filename, title, log_scale)
-
-    def log_nb_trainable_params(self, filename, title=None, log_scale=False):
-        self._log_callback(aidge_core.OperatorStats.get_nb_trainable_params, filename, title, log_scale)
+        namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})")
+        nodes = self.get_graph().get_ordered_nodes()
+        series = []
+        legend = None
+        for node in nodes:
+            if node.type() == "Producer":
+                continue
+            name = namePtrTable[node]
+            series.append([name, self.get_nb_params(node)])
+        if title is None: title = "log_nb_params"
+        self._log_bar(series, filename, title, legend, log_scale)
 
     def log_params_size(self, filename, title=None, log_scale=False):
-        self._log_callback(aidge_core.OperatorStats.get_params_size, filename, title, log_scale)
+        namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})")
+        nodes = self.get_graph().get_ordered_nodes()
+        series = []
+        legend = None
+        for node in nodes:
+            if node.type() == "Producer":
+                continue
+            name = namePtrTable[node]
+            series.append([name, self.get_params_size(node)])
+        if title is None: title = "log_params_size"
+        self._log_bar(series, filename, title, legend, log_scale)
 
     def log_nb_arithm_ops(self, filename, title=None, log_scale=False):
         self._log_callback(aidge_core.OperatorStats.get_nb_arithm_ops, filename, title, log_scale)
```
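For context, the rewritten helpers are driven the same way as the old ones. A minimal usage sketch, assuming `graph` is an `aidge_core.GraphView` (for instance loaded from ONNX) and that `StaticAnalysisExt` is importable from the module that ships this file; both names and the output paths are assumptions, not shown in this diff:

```python
import aidge_core
# StaticAnalysisExt is the class patched above; its module path is assumed here.

stats = StaticAnalysisExt(graph)  # graph: an aidge_core.GraphView (assumed)

# One bar per non-Producer node, labelled with the "{0} ({1}#{3})" ranked-name format:
stats.log_nb_params("nb_params.png")
stats.log_params_size("params_size.png", title="Params size (bits)", log_scale=True)
```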
C++ header (OperatorStats / StaticAnalysis):

```diff
@@ -33,17 +33,13 @@
 namespace Aidge {
 /**
- * @brief Base class to compute statistics from an Operator
+ * @brief Base class to compute statistics from an Operator.
  *
  */
 class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
 public:
     OperatorStats(const Operator& op);
     const Operator& getOperator() const noexcept { return mOp; }
 
-    size_t getNbParams() const;
-    virtual size_t getNbFixedParams() const { return 0; };
-    virtual size_t getNbTrainableParams() const;
-    virtual size_t getParamsSize() const;
-
     /**
      * @brief Get the total number of arithmetic operations for the operator.
@@ -126,10 +122,35 @@ class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
 public:
     StaticAnalysis(std::shared_ptr<GraphView> graph);
     const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
 
-    size_t getNbParams() const { return accumulate(&OperatorStats::getNbParams); }
-    size_t getNbFixedParams() const { return accumulate(&OperatorStats::getNbFixedParams); }
-    size_t getNbTrainableParams() const { return accumulate(&OperatorStats::getNbTrainableParams); }
-    size_t getParamsSize() const { return accumulate(&OperatorStats::getParamsSize); }
+    /**
+     * @brief Get the number of parameters associated with a node. This includes
+     * all Producers directly connected to the node's inputs, as well as all
+     * internal Producers (in the case of a meta operator).
+     *
+     * Note: this function does not check whether parameters are shared between
+     * several nodes; simply summing the parameter counts of several nodes may
+     * therefore report more parameters than actually exist.
+     *
+     * @param node Node
+     * @return size_t Number of parameters
+     */
+    virtual size_t getNbParams(std::shared_ptr<Node> node) const;
+
+    /**
+     * @brief Get the total parameter memory size, in bits, associated with a
+     * node. This includes all Producers directly connected to the node's
+     * inputs, as well as all internal Producers (in the case of a meta
+     * operator).
+     *
+     * Note: this function does not check whether parameters are shared between
+     * several nodes; simply summing the parameter sizes of several nodes may
+     * therefore report a larger size than actually exists.
+     *
+     * @param node Node
+     * @return size_t Total parameter memory, in bits
+     */
+    virtual size_t getParamsSize(std::shared_ptr<Node> node) const;
+
     size_t getNbArithmOps() const { return accumulate(&OperatorStats::getNbArithmOps); }
     size_t getNbLogicOps() const { return accumulate(&OperatorStats::getNbLogicOps); }
     size_t getNbCompOps() const { return accumulate(&OperatorStats::getNbCompOps); }
```
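The sharing caveat in these docstrings is worth spelling out: each per-node figure is exact, but a naive sum across nodes is only an upper bound. A sketch of the pitfall through the Python bindings, where `fc1` and `fc2` are hypothetical nodes fed by the same weight Producer:

```python
sa = aidge_core.StaticAnalysis(graph)  # graph: an aidge_core.GraphView (assumed)

# Each per-node count includes the shared Producer's weights, so the shared
# tensor is counted once per consumer here:
naive_sum = sa.get_nb_params(fc1) + sa.get_nb_params(fc2)
# naive_sum >= number of distinct parameters actually held by fc1 and fc2.
```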
```diff
@@ -158,9 +179,6 @@ public:
         return std::make_unique<MetaOpStats>(op);
     }
 
-    size_t getNbFixedParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbFixedParams(); }
-    size_t getNbTrainableParams() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbTrainableParams(); }
-    size_t getParamsSize() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getParamsSize(); }
     size_t getNbArithmOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbArithmOps(); }
     size_t getNbLogicOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbLogicOps(); }
     size_t getNbCompOps() const override { return StaticAnalysis(dynamic_cast<const MetaOperator_Op&>(mOp).getMicroGraph()).getNbCompOps(); }
```
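After this change, MetaOpStats only delegates the operation counts to a StaticAnalysis built over the meta operator's micro-graph; parameter counting now lives in StaticAnalysis::getNbParams, which walks the micro-graph itself. Reaching both from Python might look like the following sketch (`meta_node` is hypothetical, and the `get_op_stats` binding for getOpStats is assumed, as it is not visible in this diff):

```python
ops = sa.get_op_stats(meta_node)    # assumed binding for StaticAnalysis::getOpStats
print(ops.get_nb_arithm_ops())      # delegated: accumulated over the micro-graph
print(sa.get_nb_params(meta_node))  # includes the micro-graph's internal Producers
```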
Python bindings:

```diff
@@ -25,30 +25,6 @@ class pyOperatorStats: public OperatorStats {
 public:
     using OperatorStats::OperatorStats; // Inherit constructors
 
-    size_t getNbFixedParams() const override {
-        PYBIND11_OVERRIDE(
-            size_t,
-            OperatorStats,
-            getNbFixedParams
-        );
-    }
-
-    size_t getNbTrainableParams() const override {
-        PYBIND11_OVERRIDE(
-            size_t,
-            OperatorStats,
-            getNbTrainableParams
-        );
-    }
-
-    size_t getParamsSize() const override {
-        PYBIND11_OVERRIDE(
-            size_t,
-            OperatorStats,
-            getParamsSize
-        );
-    }
-
     size_t getNbArithmOps() const override {
         PYBIND11_OVERRIDE(
             size_t,
@@ -102,6 +78,24 @@ class pyStaticAnalysis: public StaticAnalysis {
 public:
     using StaticAnalysis::StaticAnalysis; // Inherit constructors
 
+    size_t getNbParams(std::shared_ptr<Node> node) const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            StaticAnalysis,
+            getNbParams,
+            node
+        );
+    }
+
+    size_t getParamsSize(std::shared_ptr<Node> node) const override {
+        PYBIND11_OVERRIDE(
+            size_t,
+            StaticAnalysis,
+            getParamsSize,
+            node
+        );
+    }
+
     void summary(bool incProducers) const override {
         PYBIND11_OVERRIDE(
             void,
@@ -122,10 +116,6 @@ void init_StaticAnalysis(py::module& m){
     py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr())
         .def(py::init<const Operator&>(), py::arg("op"))
         .def("get_operator", &OperatorStats::getOperator)
-        .def("get_nb_params", &OperatorStats::getNbParams)
-        .def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
-        .def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
-        .def("get_params_size", &OperatorStats::getParamsSize)
         .def("get_nb_arithm_ops", &OperatorStats::getNbArithmOps)
         .def("get_nb_logic_ops", &OperatorStats::getNbLogicOps)
        .def("get_nb_comp_ops", &OperatorStats::getNbCompOps)
@@ -140,10 +130,8 @@ void init_StaticAnalysis(py::module& m){
     py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::multiple_inheritance(), py::dynamic_attr())
         .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
         .def("get_graph", &StaticAnalysis::getGraph)
-        .def("get_nb_params", &StaticAnalysis::getNbParams)
-        .def("get_nb_fixed_params", &StaticAnalysis::getNbFixedParams)
-        .def("get_nb_trainable_params", &StaticAnalysis::getNbTrainableParams)
-        .def("get_params_size", &StaticAnalysis::getParamsSize)
+        .def("get_nb_params", &StaticAnalysis::getNbParams, py::arg("node"))
+        .def("get_params_size", &StaticAnalysis::getParamsSize, py::arg("node"))
         .def("get_nb_arithm_ops", &StaticAnalysis::getNbArithmOps)
         .def("get_nb_logic_ops", &StaticAnalysis::getNbLogicOps)
         .def("get_nb_comp_ops", &StaticAnalysis::getNbCompOps)
```
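Because getNbParams and getParamsSize are routed through PYBIND11_OVERRIDE in the pyStaticAnalysis trampoline, a Python subclass can replace them, and C++ callers such as summary() will dispatch back into Python. One subtlety: PYBIND11_OVERRIDE as written looks the override up under the C++ spelling, so the Python method must be named getNbParams, not get_nb_params. A minimal sketch, with an invented filtering rule for illustration:

```python
class MaskedStaticAnalysis(aidge_core.StaticAnalysis):
    """Hypothetical subclass: report BatchNorm parameters as zero."""
    def getNbParams(self, node):             # C++ spelling, per PYBIND11_OVERRIDE
        if node.type() == "BatchNorm":       # illustrative rule only
            return 0
        return super().get_nb_params(node)   # snake_case bound base method

# summary() runs in C++ but dispatches getNbParams(node) through the trampoline.
MaskedStaticAnalysis(graph).summary(False)   # False: incProducers, per the C++ signature
```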
C++ implementation:

```diff
@@ -17,42 +17,6 @@ Aidge::OperatorStats::OperatorStats(const Operator& op)
     //ctor
 }
 
-size_t Aidge::OperatorStats::getNbParams() const {
-    return (getNbFixedParams() + getNbTrainableParams());
-}
-
-size_t Aidge::OperatorStats::getNbTrainableParams() const {
-    size_t nbParams = 0;
-    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
-    if (opTensor) {
-        for (size_t i = 0; i < mOp.nbInputs(); ++i) {
-            if ((mOp.inputCategory(i) == InputCategory::Param
-                || mOp.inputCategory(i) == InputCategory::OptionalParam)
-                && opTensor->getInput(i))
-            {
-                nbParams += opTensor->getInput(i)->size();
-            }
-        }
-    }
-    return nbParams;
-}
-
-size_t Aidge::OperatorStats::getParamsSize() const {
-    size_t paramsSize = 0;
-    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
-    if (opTensor) {
-        for (size_t i = 0; i < mOp.nbInputs(); ++i) {
-            if ((mOp.inputCategory(i) == InputCategory::Param
-                || mOp.inputCategory(i) == InputCategory::OptionalParam)
-                && opTensor->getInput(i))
-            {
-                paramsSize += opTensor->getInput(i)->size() * getDataTypeBitWidth(opTensor->getInput(i)->dataType());
-            }
-        }
-    }
-    return paramsSize;
-}
-
 size_t Aidge::OperatorStats::getNbArithmIntOps() const {
     const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
     if (opTensor) {
```
```diff
@@ -74,8 +38,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
     fmt::println("        Layer (type)               Output Shape         Param #");
     fmt::println("================================================================================");
 
-    size_t nbTrainableParams = 0;
-    size_t nbFixedParams = 0;
+    size_t nbParams = 0;
     size_t paramsSize = 0;
     size_t fwdBwdSize = 0;
@@ -99,12 +62,10 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
             }
         }
 
-        const auto stats = getOpStats(node);
-        nbTrainableParams += stats->getNbTrainableParams();
-        nbFixedParams += stats->getNbFixedParams();
-        paramsSize += stats->getParamsSize();
+        nbParams += getNbParams(node);
+        paramsSize += getParamsSize(node);
         fmt::println("{: >36}{}{: >16}",
-            namePtrTable.at(node), outputDimsStr, stats->getNbParams());
+            namePtrTable.at(node), outputDimsStr, getNbParams(node));
     }
 
     size_t inputSize = 0;
@@ -118,9 +79,7 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
     }
 
     fmt::println("================================================================================");
-    fmt::println("Total params: {}", nbTrainableParams + nbFixedParams);
-    fmt::println("Trainable params: {}", nbTrainableParams);
-    fmt::println("Non-trainable params: {}", nbFixedParams);
+    fmt::println("Total params: {}", nbParams);
     fmt::println("--------------------------------------------------------------------------------");
     fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024);
     fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024);
```
```diff
@@ -129,6 +88,68 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
     fmt::println("--------------------------------------------------------------------------------");
 }
 
+size_t Aidge::StaticAnalysis::getNbParams(std::shared_ptr<Node> node) const {
+    const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
+
+    size_t nbParams = 0;
+
+    // Look for Producers directly attached to the node's inputs.
+    size_t i = 0;
+    for (auto parent : node->inputs()) {
+        if (parent.first && mGraph->inView(parent.first)) {
+            if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) {
+                nbParams += opTensor->getInput(i)->size();
+            }
+        }
+        ++i;
+    }
+
+    // Look for internal Producers, in case of meta-op.
+    if (!node->getOperator()->isAtomic()) {
+        const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph();
+        for (const auto internalNode : microGraph->getNodes()) {
+            if (internalNode->type() == Producer_Op::Type) {
+                const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator());
+                nbParams += internalOpTensor->getOutput(0)->size();
+            }
+        }
+    }
+
+    return nbParams;
+}
+
+size_t Aidge::StaticAnalysis::getParamsSize(std::shared_ptr<Node> node) const {
+    const auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
+
+    size_t paramsSize = 0;
+
+    // Look for Producers directly attached to the node's inputs.
+    size_t i = 0;
+    for (auto parent : node->inputs()) {
+        if (parent.first && mGraph->inView(parent.first)) {
+            if (parent.first->type() == Producer_Op::Type && opTensor->getInput(i)) {
+                paramsSize += opTensor->getInput(i)->size()
+                    * getDataTypeBitWidth(opTensor->getInput(i)->dataType());
+            }
+        }
+        ++i;
+    }
+
+    // Look for internal Producers, in case of meta-op.
+    if (!node->getOperator()->isAtomic()) {
+        const auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator())->getMicroGraph();
+        for (const auto internalNode : microGraph->getNodes()) {
+            if (internalNode->type() == Producer_Op::Type) {
+                const auto internalOpTensor = std::dynamic_pointer_cast<OperatorTensor>(internalNode->getOperator());
+                paramsSize += internalOpTensor->getOutput(0)->size()
+                    * getDataTypeBitWidth(internalOpTensor->getOutput(0)->dataType());
+            }
+        }
+    }
+
+    return paramsSize;
+}
+
 std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
     return (Registrar<OperatorStats>::exists(node->type()))
         ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
```
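Since summary() now accumulates exactly these getNbParams(node) calls, the printed Total params figure can be cross-checked from Python; note the sum inherits the shared-parameter caveat documented in the header. A closing sketch, under the same assumptions as the earlier examples:

```python
total = sum(sa.get_nb_params(n)
            for n in graph.get_ordered_nodes()
            if n.type() != "Producer")
print("Total params:", total)  # should match summary()'s footer, up to Producer handling
```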