From b9817f5db13c0a7c5aba93331d7147a8dcbb4ba5 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 14 Nov 2024 20:53:26 +0100 Subject: [PATCH] Added pooling operations --- include/aidge/graph/StaticAnalysis.hpp | 82 ++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp index 94449c562..d92356b72 100644 --- a/include/aidge/graph/StaticAnalysis.hpp +++ b/include/aidge/graph/StaticAnalysis.hpp @@ -20,11 +20,13 @@ #include "aidge/data/Tensor.hpp" #include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/FC.hpp" #include "aidge/operator/MatMul.hpp" +#include "aidge/operator/MaxPooling.hpp" #include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/ReduceSum.hpp" #include "aidge/operator/Softmax.hpp" @@ -45,6 +47,8 @@ public: * operator data flow. This includes base arithmetic operations: +, -, / and *. * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * Example of Operator with only arithmetic operatons: Conv. * * @return size_t Number of arithmetic operations. @@ -56,6 +60,8 @@ public: * operator data flow. This includes operations like logical shift, or, and... * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * Example of Operator with only logic operatons: BitShift. * * @return size_t Number of logic operations. @@ -67,6 +73,8 @@ public: * operator data flow. This includes operations like <, >, =... * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * Example of Operator with only comparison operatons: MaxPool. * * @return size_t Number of comparison operations. @@ -78,6 +86,8 @@ public: * operator data flow. This includes operations like calls to tanh(), erf(), cos()... * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * Example of Operator with only NL operatons: Tanh. * Non-linear operations are necessarily of floating-point type. * @@ -90,6 +100,8 @@ public: * Total number of operations = arithmetic ops + logic ops + comp ops + NL ops. * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * * @return size_t Number of operations. */ @@ -101,6 +113,8 @@ public: * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps() * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * * @return size_t Number of INT arithmetic operations. */ @@ -112,6 +126,8 @@ public: * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps() * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * * @return size_t Number of FP arithmetic operations. */ @@ -123,6 +139,8 @@ public: * operation counted as 2 arithmetic operations. MAC can be INT of FP. * Control flow operations (loop counters, index computation...) and memory * accesses are not included. + * A naive implementation is considered (more operations might be required + * for numerical stability in an actual implementation). * * @return size_t Number of MAC operations. */ @@ -244,6 +262,70 @@ REGISTRAR(OperatorStats, "ConvDepthWise1D", ConvStats<ConvDepthWise_Op<1>>::crea REGISTRAR(OperatorStats, "Conv2D", ConvStats<Conv_Op<2>>::create); REGISTRAR(OperatorStats, "ConvDepthWise2D", ConvStats<ConvDepthWise_Op<2>>::create); +template <class OP> +class MaxPoolingStats : public OperatorStats { +public: + MaxPoolingStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<MaxPoolingStats<OP>> create(const Operator& op) { + return std::make_unique<MaxPoolingStats<OP>>(op); + } + + size_t getNbCompOps() const override { + const OP& op_ = dynamic_cast<const OP&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const std::size_t poolSize + = std::accumulate(op_.kernelDims().cbegin(), + op_.kernelDims().cend(), + 1, + std::multiplies<size_t>()); + const std::size_t outputSize + = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2, + op_.getOutput(0)->dims().cend(), + 1, + std::multiplies<size_t>()); // NCHW... + const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW + return batchSize * ((poolSize - 1) * outputSize); + } +}; + +REGISTRAR(OperatorStats, "MaxPooling1D", MaxPoolingStats<MaxPooling_Op<1>>::create); +REGISTRAR(OperatorStats, "MaxPooling2D", MaxPoolingStats<MaxPooling_Op<2>>::create); +REGISTRAR(OperatorStats, "MaxPooling3D", MaxPoolingStats<MaxPooling_Op<3>>::create); + +template <class OP> +class AvgPoolingStats : public OperatorStats { +public: + AvgPoolingStats(const Operator& op) : OperatorStats(op) {} + + static std::unique_ptr<AvgPoolingStats<OP>> create(const Operator& op) { + return std::make_unique<AvgPoolingStats<OP>>(op); + } + + size_t getNbArithmOps() const override { + const OP& op_ = dynamic_cast<const OP&>(mOp); + AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis"); + const std::size_t poolSize + = std::accumulate(op_.kernelDims().cbegin(), + op_.kernelDims().cend(), + 1, + std::multiplies<size_t>()); + const std::size_t outputSize + = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2, + op_.getOutput(0)->dims().cend(), + 1, + std::multiplies<size_t>()); // NCHW... + const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW + // (poolSize - 1) additions + 1 division for each output + return batchSize * (poolSize * outputSize); + } +}; + +REGISTRAR(OperatorStats, "AvgPooling1D", AvgPoolingStats<AvgPooling_Op<1>>::create); +REGISTRAR(OperatorStats, "AvgPooling2D", AvgPoolingStats<AvgPooling_Op<2>>::create); +REGISTRAR(OperatorStats, "AvgPooling3D", AvgPoolingStats<AvgPooling_Op<3>>::create); +REGISTRAR(OperatorStats, "AvgPooling4D", AvgPoolingStats<AvgPooling_Op<4>>::create); + class FCStats : public OperatorStats { public: FCStats(const Operator& op) : OperatorStats(op) {} -- GitLab