diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp index 16592041ce04e16930a757578e4f42db453ed645..c668391310f6e8b09d8ece878f89f234b09c961c 100644 --- a/include/aidge/graph/StaticAnalysis.hpp +++ b/include/aidge/graph/StaticAnalysis.hpp @@ -25,7 +25,6 @@ #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/FC.hpp" #include "aidge/operator/MatMul.hpp" -#include "aidge/operator/ReLU.hpp" #include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/ReduceSum.hpp" #include "aidge/operator/Softmax.hpp" @@ -42,7 +41,7 @@ public: const Operator& getOperator() const noexcept { return mOp; } /** - * @brief Get the total number of arithmetic operations for the operator. + * @brief Get the worst case total number of arithmetic operations for the operator. * This includes base arithmetic operations: +, -, / and *. * Example of Operator with only comparison operatons: Conv. * @@ -51,7 +50,7 @@ public: virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); }; /** - * @brief Get the total number of logic operations for the operator. + * @brief Get the worst case total number of logic operations for the operator. * This includes operations like logical shift, or, and... * Example of Operator with only comparison operatons: BitShift. * @@ -60,7 +59,7 @@ public: virtual size_t getNbLogicOps() const { return 0; }; /** - * @brief Get the total number of comparison operations for the operator. + * @brief Get the worst case total number of comparison operations for the operator. * This includes operations like <, >, =... * Example of Operator with only comparison operatons: MaxPool. * @@ -69,7 +68,7 @@ public: virtual size_t getNbCompOps() const { return 0; }; /** - * @brief Get the total number of non-linear (NL) operations for the operator. + * @brief Get the worst case total number of non-linear (NL) operations for the operator. * This includes operations like calls to tanh(), erf(), cos()... * Example of Operator with only NL operatons: Tanh. * Non-linear operations are necessarily of floating-point type. @@ -79,7 +78,7 @@ public: virtual size_t getNbNLOps() const { return 0; }; /** - * @brief Get the total number of operations for the operator. + * @brief Get the worst case total number of operations for the operator. * Total number of operations = arithmetic ops + logic ops + comp ops + NL ops. * * @return size_t Number of operations. @@ -87,7 +86,7 @@ public: size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbNLOps(); }; /** - * @brief Get the total number of INT arithmetic operations for the operator. + * @brief Get the worst case total number of INT arithmetic operations for the operator. * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps() * * @return size_t Number of INT arithmetic operations. @@ -95,7 +94,7 @@ public: virtual size_t getNbArithmIntOps() const; /** - * @brief Get the total number of FP arithmetic operations for the operator. + * @brief Get the worst case total number of FP arithmetic operations for the operator. * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps() * * @return size_t Number of FP arithmetic operations. @@ -103,7 +102,7 @@ public: size_t getNbArithmFpOps() const { return getNbArithmOps() - getNbArithmIntOps(); }; /** - * @brief Get the total number of MAC operations for the operator. + * @brief Get the worst case total number of MAC operations for the operator. * * @return size_t Number of MAC operations. */ @@ -123,6 +122,14 @@ public: StaticAnalysis(std::shared_ptr<GraphView> graph); const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; } + /** + * @brief Get the Operator Stats object corresponding to the given node. + * + * @param node Node + * @return std::shared_ptr<OperatorStats> Node's Operator stats + */ + std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const; + /** * @brief Get the number of parameters associated to a node. This includes * all Producers directly connected to the node's inputs as well as all @@ -165,7 +172,6 @@ public: protected: const std::shared_ptr<GraphView> mGraph; - std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const; size_t accumulate(size_t (OperatorStats::*func)() const) const; }; @@ -198,6 +204,7 @@ public: size_t getNbMACOps() const override { const OP& op_ = dynamic_cast<const OP&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); const std::size_t weightsSize = op_.getInput(1)->size(); const std::size_t outputSize = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2, @@ -225,6 +232,7 @@ public: size_t getNbMACOps() const override { const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); const std::size_t weightsSize = op_.getInput(1)->size(); const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW return batchSize * weightsSize; @@ -243,6 +251,7 @@ public: size_t getNbMACOps() const override { const MatMul_Op& op_ = dynamic_cast<const MatMul_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); const size_t n = (op_.getInput(0)->dims().size() > 1) ? op_.getInput(0)->dims().end()[-2] : 1; const size_t k = op_.getInput(0)->dims().back(); @@ -270,13 +279,38 @@ public: } size_t getNbCompOps() const override { - const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp); + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); return op_.getOutput(0)->size(); } }; REGISTRAR(OperatorStats, "ReLU", ReLUStats::create); +class AbsStats : public OperatorStats { +public: + AbsStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<AbsStats> create(const Operator& op) { + return std::make_unique<AbsStats>(op); + } + + size_t getNbCompOps() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + return op_.getOutput(0)->size(); + } + + // This is in the worst case (all values are negative) + size_t getNbArithmOps() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + return op_.getOutput(0)->size(); + } +}; + +REGISTRAR(OperatorStats, "Abs", AbsStats::create); + class ReduceMeanStats : public OperatorStats { public: ReduceMeanStats(const Operator& op) : OperatorStats(op) {} @@ -287,6 +321,7 @@ public: size_t getNbArithmOps() const override { const ReduceMean_Op& op_ = dynamic_cast<const ReduceMean_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); const size_t nbIn = op_.getInput(0)->size(); const size_t nbOut = op_.getOutput(0)->size(); const size_t nbReduce = nbIn / nbOut; @@ -307,6 +342,7 @@ public: size_t getNbArithmOps() const override { const ReduceSum_Op& op_ = dynamic_cast<const ReduceSum_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); const size_t nbIn = op_.getInput(0)->size(); const size_t nbOut = op_.getOutput(0)->size(); const size_t nbReduce = nbIn / nbOut; @@ -327,6 +363,7 @@ public: size_t getNbArithmOps() const override { const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis(); const size_t nbReduce = op_.getInput(0)->dims()[axis]; const size_t nbOut = op_.getOutput(0)->size(); @@ -336,6 +373,7 @@ public: size_t getNbNLOps() const override { const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis(); const size_t nbReduce = op_.getInput(0)->dims()[axis]; const size_t nbOut = op_.getOutput(0)->size(); @@ -375,6 +413,7 @@ public: size_t getNbArithmOps() const override { const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); return op_.getOutput(0)->size(); } }; @@ -384,6 +423,23 @@ REGISTRAR(OperatorStats, "Sub", ElemWiseOpStats::create); REGISTRAR(OperatorStats, "Mul", ElemWiseOpStats::create); REGISTRAR(OperatorStats, "Div", ElemWiseOpStats::create); +class ElemWiseLogicOpStats : public OperatorStats { +public: + ElemWiseLogicOpStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<ElemWiseLogicOpStats> create(const Operator& op) { + return std::make_unique<ElemWiseLogicOpStats>(op); + } + + size_t getNbArithmOps() const override { + const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + return op_.getOutput(0)->size(); + } +}; + +REGISTRAR(OperatorStats, "And", ElemWiseLogicOpStats::create); + class ElemWiseNLOpStats : public OperatorStats { public: ElemWiseNLOpStats(const Operator& op) : OperatorStats(op) {} @@ -394,10 +450,12 @@ public: size_t getNbNLOps() const override { const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); return op_.getOutput(0)->size(); } }; +REGISTRAR(OperatorStats, "Atan", ElemWiseNLOpStats::create); REGISTRAR(OperatorStats, "Sqrt", ElemWiseNLOpStats::create); REGISTRAR(OperatorStats, "Erf", ElemWiseNLOpStats::create); REGISTRAR(OperatorStats, "Ln", ElemWiseNLOpStats::create); diff --git a/python_binding/graph/pybind_StaticAnalysis.cpp b/python_binding/graph/pybind_StaticAnalysis.cpp index 04a19f6cce9a0d697d4a735161fa887c395f6130..b7c704d722e81b36e8d4988a4503428918e16a5a 100644 --- a/python_binding/graph/pybind_StaticAnalysis.cpp +++ b/python_binding/graph/pybind_StaticAnalysis.cpp @@ -106,12 +106,6 @@ public: } }; -// See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions -class StaticAnalysis_Publicist : public StaticAnalysis { -public: - using StaticAnalysis::getOpStats; -}; - void init_StaticAnalysis(py::module& m){ py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr()) .def(py::init<const Operator&>(), py::arg("op")) @@ -141,7 +135,7 @@ void init_StaticAnalysis(py::module& m){ .def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps) .def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps) .def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false) - .def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node")) + .def("get_op_stats", &StaticAnalysis::getOpStats, py::arg("node")) ; } } diff --git a/unit_tests/graph/Test_Matching.cpp b/unit_tests/graph/Test_Matching.cpp index d6d98d4701cba900548d127879c9b3940cf1d739..8c5fa222a68a7f2eed329be7c49ca62d0d7ba52f 100644 --- a/unit_tests/graph/Test_Matching.cpp +++ b/unit_tests/graph/Test_Matching.cpp @@ -50,9 +50,9 @@ TEST_CASE("[core/graph] Matching") { ReLU("relu2"), PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}), ReLU("relu3"), - PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), + PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), Add("add"), - PaddedConv(8, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), + PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), ReLU("relu5"), Add("add2") }); diff --git a/unit_tests/graph/Test_StaticAnalysis.cpp b/unit_tests/graph/Test_StaticAnalysis.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1a76f0e2a0db4ce092ed109fda71dd3b25741c8 --- /dev/null +++ b/unit_tests/graph/Test_StaticAnalysis.cpp @@ -0,0 +1,68 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include <fmt/chrono.h> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/graph/StaticAnalysis.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/BatchNorm.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/operator/Producer.hpp" + +using namespace Aidge; + +TEST_CASE("[core/graph] StaticAnalysis") { + SECTION("Conv") { + auto g1 = Sequential({ + Conv(3, 4, {5, 5}, "conv1"), + ReLU("relu1"), + PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}), + ReLU("relu2"), + PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}), + ReLU("relu3"), + PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}), + Add("add"), + PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}), + ReLU("relu5"), + Add("add2") + }); + + g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1); + g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1); + g1->updateInputsOutputs(); + + g1->forwardDims({{16, 3, 512, 512}}); + + StaticAnalysis stats(g1); + REQUIRE(stats.getNbParams(g1->getNode("conv1")) == 3 * 4 * 5 * 5 + 4); + REQUIRE(stats.getNbParams(g1->getNode("conv2")) == 4 * 8 * 5 * 5 + 8); + REQUIRE(stats.getNbParams(g1->getNode("conv3")) == 8 * 16 * 3 * 3 + 16); + + const auto conv1Stats = stats.getOpStats(g1->getNode("conv1")); + REQUIRE(conv1Stats->getNbMACOps() == 1L * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmFpOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmIntOps() == 0); + + g1->getNode("conv1")->getOperator()->setDataType(DataType::Int8); + REQUIRE(conv1Stats->getNbMACOps() == 1L * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + REQUIRE(conv1Stats->getNbArithmFpOps() == 0); + REQUIRE(conv1Stats->getNbArithmIntOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4)); + } +}