Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
StaticAnalysis.cpp 5.74 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/graph/StaticAnalysis.hpp"

#include <numeric>   // std::accumulate
#include <utility>   // std::move

Aidge::OperatorStats::OperatorStats(const Operator& op)
  : mOp(op)
{
    //ctor
}

/**
 * @brief Total number of parameter elements of the operator.
 * @return getNbFixedParams() + getNbTrainableParams().
 */
size_t Aidge::OperatorStats::getNbParams() const {
    const size_t fixedCount = getNbFixedParams();
    const size_t trainableCount = getNbTrainableParams();
    return fixedCount + trainableCount;
}

/**
 * @brief Number of elements held by the operator's parameter inputs.
 *
 * Sums Tensor::size() over every input whose category is Param or
 * OptionalParam and which is currently connected (non-null).
 *
 * @return The element count, or 0 when the operator is not an OperatorTensor.
 */
size_t Aidge::OperatorStats::getNbTrainableParams() const {
    const auto* const tensorOp = dynamic_cast<const OperatorTensor*>(&mOp);
    if (tensorOp == nullptr) {
        return 0;
    }

    size_t total = 0;
    for (size_t idx = 0; idx < mOp.nbInputs(); ++idx) {
        const auto category = mOp.inputCategory(idx);
        const bool isParamInput = (category == InputCategory::Param)
                               || (category == InputCategory::OptionalParam);
        if (!isParamInput) {
            continue;
        }

        // Optional parameters may be left unconnected.
        const auto& param = tensorOp->getInput(idx);
        if (param) {
            total += param->size();
        }
    }
    return total;
}

/**
 * @brief Storage footprint of the operator's parameter inputs, in bits.
 *
 * For every connected input of category Param or OptionalParam, adds
 * `size() * getDataTypeBitWidth(dataType())`.
 *
 * @return The size in bits, or 0 when the operator is not an OperatorTensor.
 */
size_t Aidge::OperatorStats::getParamsSize() const {
    const auto* const tensorOp = dynamic_cast<const OperatorTensor*>(&mOp);
    if (!tensorOp) {
        return 0;
    }

    size_t bitCount = 0;
    for (size_t idx = 0; idx < mOp.nbInputs(); ++idx) {
        const auto category = mOp.inputCategory(idx);
        if (category != InputCategory::Param && category != InputCategory::OptionalParam) {
            continue;
        }

        const auto& param = tensorOp->getInput(idx);
        if (!param) {
            continue;
        }
        bitCount += param->size() * getDataTypeBitWidth(param->dataType());
    }
    return bitCount;
}

/**
 * @brief Number of arithmetic operations performed on integer data.
 *
 * The operator is classified by the data type of its output #0: when that
 * type is not floating point, all of getNbArithmOps() is reported.
 *
 * @return getNbArithmOps() for a non-floating-point OperatorTensor output,
 *         0 otherwise (including non-tensor operators).
 */
size_t Aidge::OperatorStats::getNbArithmIntOps() const {
    const auto* const tensorOp = dynamic_cast<const OperatorTensor*>(&mOp);
    if (tensorOp == nullptr) {
        return 0;
    }

    const bool producesFloat = isDataTypeFloatingPoint(tensorOp->getOutput(0)->dataType());
    if (producesFloat) {
        return 0;
    }
    return getNbArithmOps();
}

/**
 * @brief Create a static analyzer for a graph.
 *
 * @param graph Graph to analyse. Taken by value and moved into mGraph, so the
 *        analyzer shares ownership of the graph without an extra refcount
 *        round-trip.
 */
Aidge::StaticAnalysis::StaticAnalysis(std::shared_ptr<GraphView> graph)
  : mGraph(std::move(graph))
{
}

void Aidge::StaticAnalysis::summary(bool incProducers) const {
    fmt::println("--------------------------------------------------------------------------------");
    fmt::println("                        Layer (type)               Output Shape         Param #");
    fmt::println("================================================================================");

    size_t nbTrainableParams = 0;
    size_t nbFixedParams = 0;
    size_t paramsSize = 0;
    size_t fwdBwdSize = 0;

    const auto namePtrTable = mGraph->getRankedNodesName("{0} ({1}#{3})");
    for (const auto node : mGraph->getOrderedNodes()) {
        if (node->type() == Producer_Op::Type && !incProducers) {
            continue;
        }

        auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
        std::string outputDimsStr = fmt::format("{: >27}", "?");
        if (opTensor) {
            const auto outputDims = opTensor->getOutput(0)->dims();
            outputDimsStr = fmt::format("{: >27}", fmt::format("{}", outputDims));
  
            for (size_t out = 0; out < node->nbOutputs(); ++out) {
                const auto output = opTensor->getOutput(out);
                if (output && node->type() != Producer_Op::Type) {
                    fwdBwdSize += output->size();
                }
            }
        }

        const auto stats = getOpStats(node);
        nbTrainableParams += stats->getNbTrainableParams();
        nbFixedParams += stats->getNbFixedParams();
        paramsSize += stats->getParamsSize();
        fmt::println("{: >36}{}{: >16}",
          namePtrTable.at(node), outputDimsStr, stats->getNbParams());
    }

    size_t inputSize = 0;
    for (const auto input : mGraph->getOrderedInputs()) {
        if (input.first) {
            auto opTensor = std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator());
            if (opTensor && opTensor->getInput(input.second)) {
                inputSize += opTensor->getInput(input.second)->size();
            }
        }
    }

    fmt::println("================================================================================");
    fmt::println("Total params: {}", nbTrainableParams + nbFixedParams);
    fmt::println("Trainable params: {}", nbTrainableParams);
    fmt::println("Non-trainable params: {}", nbFixedParams);
    fmt::println("--------------------------------------------------------------------------------");
    fmt::println("Input size (MB): {}", inputSize / 8.0 / 1024 / 1024);
    fmt::println("Forward/backward pass size (MB): {}", fwdBwdSize / 8.0 / 1024 / 1024);
    fmt::println("Params size (MB): {}", paramsSize / 8.0 / 1024 / 1024);
    fmt::println("Estimated Total Size (MB): {}", (inputSize + fwdBwdSize + paramsSize) / 8.0 / 1024 / 1024);
    fmt::println("--------------------------------------------------------------------------------");
}

/**
 * @brief Build the statistics object used to analyse @p node.
 *
 * Resolution order:
 *  1. the factory registered in Registrar<OperatorStats> for node->type();
 *  2. otherwise a plain OperatorStats if the operator is atomic;
 *  3. otherwise a MetaOpStats (non-atomic operator).
 *
 * @param node Node whose operator is analysed.
 * @return A newly created statistics object bound to the node's operator.
 */
std::shared_ptr<Aidge::OperatorStats> Aidge::StaticAnalysis::getOpStats(std::shared_ptr<Node> node) const {
    const auto nodeType = node->type();
    const auto op = node->getOperator();

    if (Registrar<OperatorStats>::exists(nodeType)) {
        return Registrar<OperatorStats>::create(nodeType)(*op);
    }
    if (op->isAtomic()) {
        return std::make_shared<OperatorStats>(*op);
    }
    return std::make_shared<MetaOpStats>(*op);
}

/**
 * @brief Sum one OperatorStats metric over every node of the graph.
 *
 * For each node in mGraph->getNodes(), builds its stats via getOpStats() and
 * calls @p func on it. All nodes are visited, including Producers.
 *
 * @param func Pointer to a const, argument-less OperatorStats member returning
 *        size_t (e.g. &OperatorStats::getNbParams).
 * @return The sum of @p func over all nodes; 0 for an empty graph.
 */
size_t Aidge::StaticAnalysis::accumulate(size_t (OperatorStats::*func)() const) const {
    // Fetch the container once. Taking cbegin() and cend() from two separate
    // getNodes() calls is only valid if that accessor returns a reference.
    const auto& nodes = mGraph->getNodes();
    return std::accumulate(
        nodes.cbegin(),
        nodes.cend(),
        std::size_t(0),
        [this, func](const size_t& lhs, const std::shared_ptr<Node>& rhs) {
            return lhs + (this->getOpStats(rhs).get()->*func)();
        });
}