Skip to content
Snippets Groups Projects
Commit 5d9607d2 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Improved concept

parent 8249d798
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!219Initial version of hybrid C++/Python static analysis
Pipeline #56510 passed
import matplotlib.pyplot as plt
import aidge_core
class StaticAnalysisExt(aidge_core.StaticAnalysis):
def log_nb_params(self, filename):
namePtrTable = self.get_graph().get_ranked_nodes_name("{0} ({1}#{3})");
nodes = self.get_graph().get_ordered_nodes()
names = []
values = []
for node in nodes:
if node.type() == "Producer":
continue
if node.type() in aidge_core.get_keys_OperatorStats():
stats = aidge_core.get_key_value_OperatorStats(node.type())
else:
stats = aidge_core.OperatorStats(node.get_operator())
names.append(namePtrTable[node])
values.append(stats.get_nb_params())
plt.bar(names, values)
plt.grid(axis='y')
plt.minorticks_on()
plt.grid(axis='y', which='major', linestyle='--', color='gray')
plt.grid(axis='y', which='minor', linestyle=':', color='lightgray')
plt.gca().set_axisbelow(True)
plt.xticks(rotation='vertical')
plt.title('Number of params per operator')
plt.savefig(filename, bbox_inches='tight')
......@@ -27,16 +27,18 @@ namespace Aidge {
class StaticAnalysis : public std::enable_shared_from_this<StaticAnalysis> {
public:
StaticAnalysis(std::shared_ptr<GraphView> graph);
const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
virtual void summary(bool incProducers = false) const;
virtual ~StaticAnalysis() = default;
protected:
std::shared_ptr<GraphView> mGraph;
const std::shared_ptr<GraphView> mGraph;
};
class OperatorStats : public Registrable<OperatorStats, std::string, std::function<std::shared_ptr<OperatorStats>(const Operator&)>> {
public:
OperatorStats(const Operator& op);
const Operator& getOperator() const noexcept { return mOp; }
size_t getNbParams() const;
virtual size_t getNbFixedParams() const { return 0; };
virtual size_t getNbTrainableParams() const;
......
......@@ -147,8 +147,8 @@ void init_GraphView(py::module& m) {
// }
// })
.def("get_ranked_nodes", &GraphView::getRankedNodes)
.def("get_ranked_nodes_name", &GraphView::getRankedNodesName, py::arg("format"), py::arg("mark_non_unicity") = true)
.def("set_dataformat", &GraphView::setDataFormat, py::arg("dataformat"))
;
m.def("get_connected_graph_view", &getConnectedGraphView);
......
......@@ -115,6 +115,7 @@ public:
void init_StaticAnalysis(py::module& m){
py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::dynamic_attr())
.def(py::init<const Operator&>(), py::arg("op"))
.def("get_operator", &OperatorStats::getOperator)
.def("get_nb_params", &OperatorStats::getNbParams)
.def("get_nb_fixed_params", &OperatorStats::getNbFixedParams)
.def("get_nb_trainable_params", &OperatorStats::getNbTrainableParams)
......@@ -130,6 +131,7 @@ void init_StaticAnalysis(py::module& m){
py::class_<StaticAnalysis, std::shared_ptr<StaticAnalysis>, pyStaticAnalysis>(m, "StaticAnalysis", py::dynamic_attr())
.def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
.def("get_graph", &StaticAnalysis::getGraph)
.def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
;
}
......
......@@ -114,9 +114,9 @@ void Aidge::StaticAnalysis::summary(bool incProducers) const {
fmt::println("Trainable params: {}", nbTrainableParams);
fmt::println("Non-trainable params: {}", nbFixedParams);
fmt::println("--------------------------------------------------------------------------------");
fmt::println("Input size (MB): {}", inputSize / 8 / 1024 / 1024);
fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8 / 1024 / 1024);
fmt::println("Params size (MB): {}", paramsSize / 8 / 1024 / 1024);
fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8 / 1024 / 1024);
fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024);
fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024);
fmt::println("Params size (MB): {}", paramsSize / 8.0 / 1024 / 1024);
fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
fmt::println("--------------------------------------------------------------------------------");
}
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment