/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graph/StaticAnalysis.hpp"
Aidge::OperatorStats::OperatorStats(const Operator& op)
    : mOp(op)
{
    //ctor
}

size_t Aidge::OperatorStats::getNbParams() const {
    return (getNbFixedParams() + getNbTrainableParams());
}

size_t Aidge::OperatorStats::getNbTrainableParams() const {
    // Trainable parameters are the elements of the connected Param
    // (or OptionalParam) inputs of the operator.
    size_t nbParams = 0;
    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
    if (opTensor) {
        for (size_t i = 0; i < mOp.nbInputs(); ++i) {
            if ((mOp.inputCategory(i) == InputCategory::Param
                    || mOp.inputCategory(i) == InputCategory::OptionalParam)
                && opTensor->getInput(i))
            {
                nbParams += opTensor->getInput(i)->size();
            }
        }
    }
    return nbParams;
}

size_t Aidge::OperatorStats::getParamsSize() const {
    // Total parameters size, in bits (element count times data type bit width).
    size_t paramsSize = 0;
    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
    if (opTensor) {
        for (size_t i = 0; i < mOp.nbInputs(); ++i) {
            if ((mOp.inputCategory(i) == InputCategory::Param
                    || mOp.inputCategory(i) == InputCategory::OptionalParam)
                && opTensor->getInput(i))
            {
                paramsSize += opTensor->getInput(i)->size() * getDataTypeBitWidth(opTensor->getInput(i)->dataType());
            }
        }
    }
    return paramsSize;
}

size_t Aidge::OperatorStats::getNbArithmIntOps() const {
    // Arithmetic operations count as integer operations only when the
    // operator's output data type is not a floating-point type.
    const auto opTensor = dynamic_cast<const OperatorTensor*>(&mOp);
    if (opTensor) {
        if (!isDataTypeFloatingPoint(opTensor->getOutput(0)->dataType())) {
            return getNbArithmOps();
        }
    }
    return 0;
}

Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
    : mGraph(graph)
{
    //ctor
}

void Aidge::StaticAnalysis::summary(bool incProducers) const {
    fmt::println("--------------------------------------------------------------------------------");
    fmt::println("                        Layer (type)               Output Shape         Param #");
    fmt::println("================================================================================");

    size_t nbTrainableParams = 0;
    size_t nbFixedParams = 0;
    size_t paramsSize = 0;
    size_t fwdBwdSize = 0;

    const auto namePtrTable = mGraph->getRankedNodesName("{0} ({1}#{3})");
    for (const auto node : mGraph->getOrderedNodes()) {
        if (node->type() == Producer_Op::Type && !incProducers) {
            continue;
        }

        auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
        std::string outputDimsStr = fmt::format("{: >27}", "?");
        if (opTensor) {
            const auto outputDims = opTensor->getOutput(0)->dims();
            outputDimsStr = fmt::format("{: >27}", fmt::format("{}", outputDims));

            // Accumulate the size of every non-Producer output tensor: this is
            // the memory needed for the forward/backward pass activations.
            for (size_t out = 0; out < node->nbOutputs(); ++out) {
                const auto output = opTensor->getOutput(out);
                if (output && node->type() != Producer_Op::Type) {
                    fwdBwdSize += output->size();
                }
            }
        }

        const auto stats = getOpStats(node);
        nbTrainableParams += stats->getNbTrainableParams();
        nbFixedParams += stats->getNbFixedParams();
        paramsSize += stats->getParamsSize();
        fmt::println("{: >36}{}{: >16}",
            namePtrTable.at(node), outputDimsStr, stats->getNbParams());
    }

    size_t inputSize = 0;
    for (const auto input : mGraph->getOrderedInputs()) {
        if (input.first) {
            auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator());
            if (opTensor && opTensor->getInput(input.second)) {
                inputSize += opTensor->getInput(input.second)->size();
            }
        }
    }

    fmt::println("================================================================================");
    fmt::println("Total params: {}", nbTrainableParams + nbFixedParams);
    fmt::println("Trainable params: {}", nbTrainableParams);
    fmt::println("Non-trainable params: {}", nbFixedParams);
    fmt::println("--------------------------------------------------------------------------------");
    fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024);
    fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024);
    fmt::println("Params size (MB): {}", paramsSize / 8.0 / 1024 / 1024);
    fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
    fmt::println("--------------------------------------------------------------------------------");
}

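// Usage sketch (comment only, not compiled as part of this file): assuming an
// already built and dimensioned GraphView held in a shared_ptr named `graph`
// (its construction is not shown here), the summary table above is printed with:
//
//     Aidge::StaticAnalysis analysis(graph);
//     analysis.summary(false);  // skip Producer nodes in the table
//     analysis.summary(true);   // include Producer nodes
//
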
std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
    // Use a registered, operator-specific OperatorStats implementation when one
    // exists for this node type; otherwise fall back to the generic
    // OperatorStats for atomic operators and to MetaOpStats for meta-operators.
    return (Registrar<OperatorStats>::exists(node->type()))
        ? Registrar<OperatorStats>::create(node->type())(*(node->getOperator()))
        : (node->getOperator()->isAtomic())
            ? std::make_shared<OperatorStats>(*(node->getOperator()))
            : std::make_shared<MetaOpStats>(*(node->getOperator()));
}

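// Extension sketch (comment only, hypothetical names): operator-specific
// statistics plug into getOpStats() above through Registrar<OperatorStats>,
// keyed on the node type string, with a factory taking a const Operator&.
// Assuming the per-operator counters (e.g. getNbArithmOps()) are virtual hooks
// of OperatorStats, a specialization could look like:
//
//     class MyOpStats : public OperatorStats {
//     public:
//         MyOpStats(const Operator& op) : OperatorStats(op) {}
//         size_t getNbArithmOps() const override {
//             return 0;  // operator-specific count would go here
//         }
//     };
//
// The exact registration helper/macro is defined by the Aidge registration
// headers and is not reproduced here.
//
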
size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const {
    // Sum the given OperatorStats member function over every node of the graph.
    return std::accumulate(
        mGraph->getNodes().cbegin(),
        mGraph->getNodes().cend(),
        std::size_t(0),
        [this, func](const size_t& lhs, const std::shared_ptr<Node>& rhs) {
            return lhs + (this->getOpStats(rhs).get()->*func)();
        });
}
