Commit 2d042cf7 authored by Olivier BICHLER, committed by Axel Farrugia

Several improvements

parent 035cade1
Related merge requests: !279 (v0.4.0), !250 ([Feat](Exports) Add custom options to exports)
@@ -25,7 +25,6 @@
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/Softmax.hpp"
@@ -42,7 +41,7 @@ public:
const Operator& getOperator() const noexcept { return mOp; }
/**
-* @brief Get the total number of arithmetic operations for the operator.
+* @brief Get the worst-case total number of arithmetic operations for the operator.
* This includes base arithmetic operations: +, -, / and *.
* Example of Operator with only arithmetic operations: Conv.
*
@@ -51,7 +50,7 @@ public:
virtual size_t getNbArithmOps() const { return 2 * getNbMACOps(); };
/**
-* @brief Get the total number of logic operations for the operator.
+* @brief Get the worst-case total number of logic operations for the operator.
* This includes operations like logical shift, or, and...
* Example of Operator with only logic operations: BitShift.
*
@@ -60,7 +59,7 @@ public:
virtual size_t getNbLogicOps() const { return 0; };
/**
-* @brief Get the total number of comparison operations for the operator.
+* @brief Get the worst-case total number of comparison operations for the operator.
* This includes operations like <, >, =...
* Example of Operator with only comparison operations: MaxPool.
*
@@ -69,7 +68,7 @@ public:
virtual size_t getNbCompOps() const { return 0; };
/**
-* @brief Get the total number of non-linear (NL) operations for the operator.
+* @brief Get the worst-case total number of non-linear (NL) operations for the operator.
* This includes operations like calls to tanh(), erf(), cos()...
* Example of Operator with only NL operations: Tanh.
* Non-linear operations are necessarily of floating-point type.
@@ -79,7 +78,7 @@ public:
virtual size_t getNbNLOps() const { return 0; };
/**
-* @brief Get the total number of operations for the operator.
+* @brief Get the worst-case total number of operations for the operator.
* Total number of operations = arithmetic ops + logic ops + comp ops + NL ops.
*
* @return size_t Number of operations.
@@ -87,7 +86,7 @@ public:
size_t getNbOps() const { return getNbArithmOps() + getNbLogicOps() + getNbCompOps() + getNbNLOps(); };
/**
-* @brief Get the total number of INT arithmetic operations for the operator.
+* @brief Get the worst-case total number of INT arithmetic operations for the operator.
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
*
* @return size_t Number of INT arithmetic operations.
@@ -95,7 +94,7 @@ public:
virtual size_t getNbArithmIntOps() const;
/**
-* @brief Get the total number of FP arithmetic operations for the operator.
+* @brief Get the worst-case total number of FP arithmetic operations for the operator.
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
*
* @return size_t Number of FP arithmetic operations.
@@ -103,7 +102,7 @@ public:
size_t getNbArithmFpOps() const { return getNbArithmOps() - getNbArithmIntOps(); };
/**
-* @brief Get the total number of MAC operations for the operator.
+* @brief Get the worst-case total number of MAC operations for the operator.
*
* @return size_t Number of MAC operations.
*/
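For orientation: each MAC is one multiplication plus one accumulation, which is why the default getNbArithmOps() returns 2 * getNbMACOps(). A sketch of how the counters compose (not part of the diff; stats stands for a hypothetical std::shared_ptr<OperatorStats> of a Conv node):

    size_t total = stats->getNbArithmOps()  // 2 * MACs for a Conv
                 + stats->getNbLogicOps()   // 0 for a Conv
                 + stats->getNbCompOps()    // 0 for a Conv
                 + stats->getNbNLOps();     // 0 for a Conv
    assert(total == stats->getNbOps());     // matches the sum by definition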
@@ -123,6 +122,14 @@ public:
StaticAnalysis(std::shared_ptr<GraphView> graph);
const std::shared_ptr<GraphView> getGraph() const noexcept { return mGraph; }
+/**
+* @brief Get the OperatorStats object corresponding to the given node.
+*
+* @param node Node
+* @return std::shared_ptr<OperatorStats> Node's OperatorStats
+*/
+std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const;
/**
* @brief Get the number of parameters associated with a node. This includes
* all Producers directly connected to the node's inputs as well as all
@@ -165,7 +172,6 @@
protected:
const std::shared_ptr<GraphView> mGraph;
-std::shared_ptr<OperatorStats> getOpStats(std::shared_ptr<Node> node) const;
size_t accumulate(size_t (OperatorStats::*func)() const) const;
};
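With getOpStats() promoted from the protected section to the public interface, callers no longer need a workaround to reach per-node stats. A minimal usage sketch, assuming graph is a dims-forwarded std::shared_ptr<GraphView> containing a node named "conv1":

    StaticAnalysis stats(graph);
    auto convStats = stats.getOpStats(graph->getNode("conv1"));
    fmt::print("conv1 MACs: {}\n", convStats->getNbMACOps());
    fmt::print("graph MACs: {}\n", stats.getNbMACOps());  // accumulated over all nodes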
@@ -198,6 +204,7 @@
size_t getNbMACOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t weightsSize = op_.getInput(1)->size();
const std::size_t outputSize
= std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
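The computation is truncated in this excerpt; for a standard Conv it amounts to MACs = batchSize * H_out * W_out * weightsSize, with weightsSize = C_out * C_in * kH * kW, as exercised by the test at the bottom of this commit.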
@@ -225,6 +232,7 @@
size_t getNbMACOps() const override {
const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t weightsSize = op_.getInput(1)->size();
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
return batchSize * weightsSize;
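For FC, weightsSize = C_in * C_out, so the count is one MAC per weight per batch element: a hypothetical 128-input, 10-output layer at batch size 16 costs 16 * 128 * 10 = 20480 MACs.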
@@ -243,6 +251,7 @@
size_t getNbMACOps() const override {
const MatMul_Op& op_ = dynamic_cast<const MatMul_Op&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t n = (op_.getInput(0)->dims().size() > 1)
? op_.getInput(0)->dims().end()[-2] : 1;
const size_t k = op_.getInput(0)->dims().back();
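The remainder is truncated here; for a matrix product of shapes (n, k) x (k, m) the count is presumably n * k * m MACs, repeated over any leading broadcast dimensions.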
@@ -270,13 +279,38 @@ public:
}
size_t getNbCompOps() const override {
-const ReLU_Op& op_ = dynamic_cast<const ReLU_Op&>(mOp);
+const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "ReLU", ReLUStats::create);
+class AbsStats : public OperatorStats {
+public:
+AbsStats(const Operator& op) : OperatorStats(op) {}
+static std::unique_ptr<AbsStats> create(const Operator& op) {
+return std::make_unique<AbsStats>(op);
+}
+size_t getNbCompOps() const override {
+const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+return op_.getOutput(0)->size();
+}
+// This is the worst case (all values are negative)
+size_t getNbArithmOps() const override {
+const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+return op_.getOutput(0)->size();
+}
+};
+REGISTRAR(OperatorStats, "Abs", AbsStats::create);
class ReduceMeanStats : public OperatorStats {
public:
ReduceMeanStats(const Operator& op) : OperatorStats(op) {}
@@ -287,6 +321,7 @@ public:
size_t getNbArithmOps() const override {
const ReduceMean_Op& op_ = dynamic_cast<const ReduceMean_Op&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t nbIn = op_.getInput(0)->size();
const size_t nbOut = op_.getOutput(0)->size();
const size_t nbReduce = nbIn / nbOut;
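The return statement is truncated here; presumably each output element costs nbReduce - 1 additions plus one division, i.e. nbOut * nbReduce = nbIn arithmetic operations in total. ReduceSum below follows the same pattern without the final division.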
@@ -307,6 +342,7 @@ public:
size_t getNbArithmOps() const override {
const ReduceSum_Op& op_ = dynamic_cast<const ReduceSum_Op&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t nbIn = op_.getInput(0)->size();
const size_t nbOut = op_.getOutput(0)->size();
const size_t nbReduce = nbIn / nbOut;
@@ -327,6 +363,7 @@ public:
size_t getNbArithmOps() const override {
const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
const size_t nbReduce = op_.getInput(0)->dims()[axis];
const size_t nbOut = op_.getOutput(0)->size();
@@ -336,6 +373,7 @@ public:
size_t getNbNLOps() const override {
const Softmax_Op& op_ = dynamic_cast<const Softmax_Op&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const size_t axis = (op_.axis() >= 0) ? op_.axis() : op_.getInput(0)->nbDims() + op_.axis();
const size_t nbReduce = op_.getInput(0)->dims()[axis];
const size_t nbOut = op_.getOutput(0)->size();
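Both return statements are truncated here; along the reduced axis of length nbReduce, each output element presumably costs roughly one addition and one division (arithmetic) plus one exp() call (NL), with negative axis values normalized as shown above.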
@@ -375,6 +413,7 @@ public:
size_t getNbArithmOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
@@ -384,6 +423,23 @@ REGISTRAR(OperatorStats, "Sub", ElemWiseOpStats::create);
REGISTRAR(OperatorStats, "Mul", ElemWiseOpStats::create);
REGISTRAR(OperatorStats, "Div", ElemWiseOpStats::create);
+class ElemWiseLogicOpStats : public OperatorStats {
+public:
+ElemWiseLogicOpStats(const Operator& op) : OperatorStats(op) {}
+static std::unique_ptr<ElemWiseLogicOpStats> create(const Operator& op) {
+return std::make_unique<ElemWiseLogicOpStats>(op);
+}
+// One logic operation (e.g. and) per output element
+size_t getNbLogicOps() const override {
+const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+return op_.getOutput(0)->size();
+}
+};
+REGISTRAR(OperatorStats, "And", ElemWiseLogicOpStats::create);
class ElemWiseNLOpStats : public OperatorStats {
public:
ElemWiseNLOpStats(const Operator& op) : OperatorStats(op) {}
@@ -394,10 +450,12 @@ public:
size_t getNbNLOps() const override {
const OperatorTensor& op_ = dynamic_cast<const OperatorTensor&>(mOp);
+AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
return op_.getOutput(0)->size();
}
};
REGISTRAR(OperatorStats, "Atan", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Sqrt", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Erf", ElemWiseNLOpStats::create);
REGISTRAR(OperatorStats, "Ln", ElemWiseNLOpStats::create);
......
@@ -106,12 +106,6 @@ public:
}
};
-// See https://pybind11.readthedocs.io/en/stable/advanced/classes.html#binding-protected-member-functions
-class StaticAnalysis_Publicist : public StaticAnalysis {
-public:
-using StaticAnalysis::getOpStats;
-};
void init_StaticAnalysis(py::module& m){
py::class_<OperatorStats, std::shared_ptr<OperatorStats>, pyOperatorStats>(m, "OperatorStats", py::multiple_inheritance(), py::dynamic_attr())
.def(py::init<const Operator&>(), py::arg("op"))
@@ -141,7 +135,7 @@ void init_StaticAnalysis(py::module& m){
.def("get_nb_arithm_fp_ops", &StaticAnalysis::getNbArithmFpOps)
.def("get_nb_mac_ops", &StaticAnalysis::getNbMACOps)
.def("summary", &StaticAnalysis::summary, py::arg("inc_producers") = false)
.def("get_op_stats", &StaticAnalysis_Publicist::getOpStats, py::arg("node"))
.def("get_op_stats", &StaticAnalysis::getOpStats, py::arg("node"))
;
}
}
@@ -50,9 +50,9 @@ TEST_CASE("[core/graph] Matching") {
ReLU("relu2"),
PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
ReLU("relu3"),
PaddedConv(8, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
Add("add"),
PaddedConv(8, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
ReLU("relu5"),
Add("add2")
});
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <fmt/chrono.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graph/StaticAnalysis.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Producer.hpp"
using namespace Aidge;
TEST_CASE("[core/graph] StaticAnalysis") {
SECTION("Conv") {
auto g1 = Sequential({
Conv(3, 4, {5, 5}, "conv1"),
ReLU("relu1"),
PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
ReLU("relu2"),
PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
ReLU("relu3"),
PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
Add("add"),
PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
ReLU("relu5"),
Add("add2")
});
g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1);
g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1);
g1->updateInputsOutputs();
g1->forwardDims({{16, 3, 512, 512}});
StaticAnalysis stats(g1);
REQUIRE(stats.getNbParams(g1->getNode("conv1")) == 3 * 4 * 5 * 5 + 4);
REQUIRE(stats.getNbParams(g1->getNode("conv2")) == 4 * 8 * 5 * 5 + 8);
REQUIRE(stats.getNbParams(g1->getNode("conv3")) == 8 * 16 * 3 * 3 + 16);
const auto conv1Stats = stats.getOpStats(g1->getNode("conv1"));
REQUIRE(conv1Stats->getNbMACOps() == 1L * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmFpOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmIntOps() == 0);
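// Switching conv1 to an integer data type should reclassify its arithmetic
// ops from floating-point to integer, leaving the totals unchanged.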
g1->getNode("conv1")->getOperator()->setDataType(DataType::Int8);
REQUIRE(conv1Stats->getNbMACOps() == 1L * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4));
REQUIRE(conv1Stats->getNbArithmFpOps() == 0);
REQUIRE(conv1Stats->getNbArithmIntOps() == 2L * (16 * 508 * 508) * (5 * 5 * 3 * 4));
}
}