From b9817f5db13c0a7c5aba93331d7147a8dcbb4ba5 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 14 Nov 2024 20:53:26 +0100
Subject: [PATCH] Added pooling operations

---
 include/aidge/graph/StaticAnalysis.hpp | 82 ++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/include/aidge/graph/StaticAnalysis.hpp b/include/aidge/graph/StaticAnalysis.hpp
index 94449c562..d92356b72 100644
--- a/include/aidge/graph/StaticAnalysis.hpp
+++ b/include/aidge/graph/StaticAnalysis.hpp
@@ -20,11 +20,13 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 
+#include "aidge/operator/AvgPooling.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ConvDepthWise.hpp"
 #include "aidge/operator/FC.hpp"
 #include "aidge/operator/MatMul.hpp"
+#include "aidge/operator/MaxPooling.hpp"
 #include "aidge/operator/ReduceMean.hpp"
 #include "aidge/operator/ReduceSum.hpp"
 #include "aidge/operator/Softmax.hpp"
@@ -45,6 +47,8 @@ public:
      * operator data flow. This includes base arithmetic operations: +, -, / and *.
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * Example of Operator with only arithmetic operatons: Conv.
      * 
      * @return size_t Number of arithmetic operations.
@@ -56,6 +60,8 @@ public:
      * operator data flow. This includes operations like logical shift, or, and...
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * Example of Operator with only logic operatons: BitShift.
      * 
      * @return size_t Number of logic operations.
@@ -67,6 +73,8 @@ public:
      * operator data flow. This includes operations like <, >, =...
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * Example of Operator with only comparison operatons: MaxPool.
      * 
      * @return size_t Number of comparison operations.
@@ -78,6 +86,8 @@ public:
      * operator data flow. This includes operations like calls to tanh(), erf(), cos()...
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * Example of Operator with only NL operatons: Tanh.
      * Non-linear operations are necessarily of floating-point type.
      * 
@@ -90,6 +100,8 @@ public:
      * Total number of operations = arithmetic ops + logic ops + comp ops + NL ops.
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * 
      * @return size_t Number of operations.
      */
@@ -101,6 +113,8 @@ public:
      * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * 
      * @return size_t Number of INT arithmetic operations.
      */
@@ -112,6 +126,8 @@ public:
      * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * 
      * @return size_t Number of FP arithmetic operations.
      */
@@ -123,6 +139,8 @@ public:
      * operation counted as 2 arithmetic operations. MAC can be INT of FP.
      * Control flow operations (loop counters, index computation...) and memory 
      * accesses are not included.
+     * A naive implementation is considered (more operations might be required 
+     * for numerical stability in an actual implementation).
      * 
      * @return size_t Number of MAC operations.
      */
@@ -244,6 +262,70 @@ REGISTRAR(OperatorStats, "ConvDepthWise1D", ConvStats<ConvDepthWise_Op<1>>::crea
 REGISTRAR(OperatorStats, "Conv2D", ConvStats<Conv_Op<2>>::create);
 REGISTRAR(OperatorStats, "ConvDepthWise2D", ConvStats<ConvDepthWise_Op<2>>::create);
 
+template <class OP>
+class MaxPoolingStats : public OperatorStats {
+public:
+    MaxPoolingStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<MaxPoolingStats<OP>> create(const Operator& op) {
+        return std::make_unique<MaxPoolingStats<OP>>(op);
+    }
+
+    size_t getNbCompOps() const override {
+        const OP& op_ = dynamic_cast<const OP&>(mOp);
+        AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+	    const std::size_t poolSize
+            = std::accumulate(op_.kernelDims().cbegin(),
+                              op_.kernelDims().cend(),
+                              1,
+                              std::multiplies<size_t>());
+        const std::size_t outputSize
+            = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
+                              op_.getOutput(0)->dims().cend(),
+                              1,
+                              std::multiplies<size_t>()); // NCHW...
+        const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
+        return batchSize * ((poolSize - 1) * outputSize);
+    }
+};
+
+REGISTRAR(OperatorStats, "MaxPooling1D", MaxPoolingStats<MaxPooling_Op<1>>::create);
+REGISTRAR(OperatorStats, "MaxPooling2D", MaxPoolingStats<MaxPooling_Op<2>>::create);
+REGISTRAR(OperatorStats, "MaxPooling3D", MaxPoolingStats<MaxPooling_Op<3>>::create);
+
+template <class OP>
+class AvgPoolingStats : public OperatorStats {
+public:
+    AvgPoolingStats(const Operator& op) : OperatorStats(op) {}
+
+    static std::unique_ptr<AvgPoolingStats<OP>> create(const Operator& op) {
+        return std::make_unique<AvgPoolingStats<OP>>(op);
+    }
+
+    size_t getNbArithmOps() const override {
+        const OP& op_ = dynamic_cast<const OP&>(mOp);
+        AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
+	    const std::size_t poolSize
+            = std::accumulate(op_.kernelDims().cbegin(),
+                              op_.kernelDims().cend(),
+                              1,
+                              std::multiplies<size_t>());
+        const std::size_t outputSize
+            = std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
+                              op_.getOutput(0)->dims().cend(),
+                              1,
+                              std::multiplies<size_t>()); // NCHW...
+        const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
+        // (poolSize - 1) additions + 1 division for each output
+        return batchSize * (poolSize * outputSize);
+    }
+};
+
+REGISTRAR(OperatorStats, "AvgPooling1D", AvgPoolingStats<AvgPooling_Op<1>>::create);
+REGISTRAR(OperatorStats, "AvgPooling2D", AvgPoolingStats<AvgPooling_Op<2>>::create);
+REGISTRAR(OperatorStats, "AvgPooling3D", AvgPoolingStats<AvgPooling_Op<3>>::create);
+REGISTRAR(OperatorStats, "AvgPooling4D", AvgPoolingStats<AvgPooling_Op<4>>::create);
+
 class FCStats : public OperatorStats {
 public:
     FCStats(const Operator& op) : OperatorStats(op) {}
-- 
GitLab