Skip to content
Snippets Groups Projects
Commit a0aaa269 authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Axel Farrugia
Browse files

Improved concept

parent 4daa23a0
No related branches found
No related tags found
2 merge requests!279v0.4.0,!250[Feat](Exports) Add custom options to exports
import matplotlib.pyplot as plt
import aidge_core
class StaticAnalysisExt(aidge_core.StaticAnalysis):
def log_nb_params(self, filename):
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
names = []
values = []
for node in nodes:
if node.type() == "Producer":
continue
if node.type() in aidge_core.get_keys_OperatorStats():
stats = aidge_core.get_key_value_OperatorStats(node.type())
else:
stats = aidge_core.OperatorStats(node.get_operator())
names.append(namePtrTable[node])
values.append(stats.get_nb_params())
plt.bar(names, values)
plt.grid(axis='y')
plt.minorticks_on()
plt.grid(axis='y', which='major', linestyle='--', color='gray')
plt.grid(axis='y', which='minor', linestyle=':', color='lightgray')
plt.gca().set_axisbelow(True)
plt.xticks(rotation='vertical')
plt.title('Number of params per operator')
plt.savefig(filename, bbox_inches='tight')
......@@ -27,16 +27,18 @@ namespace Aidge {
class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
public:
StaticAnalysis(std::shared_ptr<GraphView> graph);
const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
virtual void summary(bool incProducers = false) const;
virtual ~StaticAnalysis() = default;
protected:
std::shared_ptr<GraphView> mGraph;
const std::shared_ptr<GraphView> mGraph;
};
class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
public:
OperatorStats(const Operator& op);
const Operator& getOperator() const noexcept { return mOp; }
size_t getNbParams() const;
virtual size_t getNbFixedParams() const { return 0; };
virtual size_t getNbTrainableParams() const;
......
......@@ -149,8 +149,8 @@ void init_GraphView(py::module& m) {
// }
// })
.def("get_ranked_nodes", &GraphView::getRankedNodes)
.def("get_ranked_nodes_name", &GraphView::getRankedNodesName, py::arg("format"), py::arg("mark_non_unicity") = true)
.def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat"))
;
m.def("get_connected_graph_view", &getConnectedGraphView);
......
......@@ -115,6 +115,7 @@ public:
void init_StaticAnalysis(py::module& m){
py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::dynamic_attr())
.def(py::init<const Operator&>(), py::arg("op"))
.def("get_operator", &OperatorStats::getOperator)
.def("get_nb_params", &OperatorStats::getNbParams)
.def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
.def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
......@@ -130,6 +131,7 @@ void init_StaticAnalysis(py::module& m){
py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::dynamic_attr())
.def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
.def("get_graph", &StaticAnalysis::getGraph)
.def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
;
}
......
......@@ -114,9 +114,9 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
fmt::println("Trainable params: {}", nbTrainableParams);
fmt::println("Non-trainable params: {}", nbFixedParams);
fmt::println("--------------------------------------------------------------------------------");
fmt::println("Input size (MB): {}", inputSize / 8 / 1024 / 1024);
fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8 / 1024 / 1024);
fmt::println("Params size (MB): {}", paramsSize / 8 / 1024 / 1024);
fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8 / 1024 / 1024);
fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024);
fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024);
fmt::println("Params size (MB): {}", paramsSize / 8.0 / 1024 / 1024);
fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
fmt::println("--------------------------------------------------------------------------------");
}
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment