Skip to content
Snippets Groups Projects
Commit b9817f5d authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Axel Farrugia
Browse files

Added pooling operations

parent 0654c8df
No related branches found
No related tags found
2 merge requests!279v0.4.0,!250[Feat](Exports) Add custom options to exports
This commit is part of merge request !250. Comments created here will be created in the context of that merge request.
...@@ -20,11 +20,13 @@ ...@@ -20,11 +20,13 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp" #include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp" #include "aidge/operator/MatMul.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/ReduceMean.hpp" #include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReduceSum.hpp" #include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/Softmax.hpp" #include "aidge/operator/Softmax.hpp"
...@@ -45,6 +47,8 @@ public: ...@@ -45,6 +47,8 @@ public:
* operator data flow. This includes base arithmetic operations: +, -, / and *. * operator data flow. This includes base arithmetic operations: +, -, / and *.
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only arithmetic operatons: Conv. * Example of Operator with only arithmetic operatons: Conv.
* *
* @return size_t Number of arithmetic operations. * @return size_t Number of arithmetic operations.
...@@ -56,6 +60,8 @@ public: ...@@ -56,6 +60,8 @@ public:
* operator data flow. This includes operations like logical shift, or, and... * operator data flow. This includes operations like logical shift, or, and...
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only logic operatons: BitShift. * Example of Operator with only logic operatons: BitShift.
* *
* @return size_t Number of logic operations. * @return size_t Number of logic operations.
...@@ -67,6 +73,8 @@ public: ...@@ -67,6 +73,8 @@ public:
* operator data flow. This includes operations like <, >, =... * operator data flow. This includes operations like <, >, =...
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only comparison operatons: MaxPool. * Example of Operator with only comparison operatons: MaxPool.
* *
* @return size_t Number of comparison operations. * @return size_t Number of comparison operations.
...@@ -78,6 +86,8 @@ public: ...@@ -78,6 +86,8 @@ public:
* operator data flow. This includes operations like calls to tanh(), erf(), cos()... * operator data flow. This includes operations like calls to tanh(), erf(), cos()...
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* Example of Operator with only NL operatons: Tanh. * Example of Operator with only NL operatons: Tanh.
* Non-linear operations are necessarily of floating-point type. * Non-linear operations are necessarily of floating-point type.
* *
...@@ -90,6 +100,8 @@ public: ...@@ -90,6 +100,8 @@ public:
* Total number of operations = arithmetic ops + logic ops + comp ops + NL ops. * Total number of operations = arithmetic ops + logic ops + comp ops + NL ops.
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* *
* @return size_t Number of operations. * @return size_t Number of operations.
*/ */
...@@ -101,6 +113,8 @@ public: ...@@ -101,6 +113,8 @@ public:
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps() * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* *
* @return size_t Number of INT arithmetic operations. * @return size_t Number of INT arithmetic operations.
*/ */
...@@ -112,6 +126,8 @@ public: ...@@ -112,6 +126,8 @@ public:
* Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps() * Such that getNbArithmOps() = getNbArithmIntOps() + getNbArithmFpOps()
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* *
* @return size_t Number of FP arithmetic operations. * @return size_t Number of FP arithmetic operations.
*/ */
...@@ -123,6 +139,8 @@ public: ...@@ -123,6 +139,8 @@ public:
* operation counted as 2 arithmetic operations. MAC can be INT of FP. * operation counted as 2 arithmetic operations. MAC can be INT of FP.
* Control flow operations (loop counters, index computation...) and memory * Control flow operations (loop counters, index computation...) and memory
* accesses are not included. * accesses are not included.
* A naive implementation is considered (more operations might be required
* for numerical stability in an actual implementation).
* *
* @return size_t Number of MAC operations. * @return size_t Number of MAC operations.
*/ */
...@@ -244,6 +262,70 @@ REGISTRAR(OperatorStats, "ConvDepthWise1D", ConvStats<ConvDepthWise_Op<1>>::crea ...@@ -244,6 +262,70 @@ REGISTRAR(OperatorStats, "ConvDepthWise1D", ConvStats<ConvDepthWise_Op<1>>::crea
REGISTRAR(OperatorStats, "Conv2D", ConvStats<Conv_Op<2>>::create); REGISTRAR(OperatorStats, "Conv2D", ConvStats<Conv_Op<2>>::create);
REGISTRAR(OperatorStats, "ConvDepthWise2D", ConvStats<ConvDepthWise_Op<2>>::create); REGISTRAR(OperatorStats, "ConvDepthWise2D", ConvStats<ConvDepthWise_Op<2>>::create);
template <class OP>
class MaxPoolingStats : public OperatorStats {
public:
MaxPoolingStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<MaxPoolingStats<OP>> create(const Operator& op) {
return std::make_unique<MaxPoolingStats<OP>>(op);
}
size_t getNbCompOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t poolSize
= std::accumulate(op_.kernelDims().cbegin(),
op_.kernelDims().cend(),
1,
std::multiplies<size_t>());
const std::size_t outputSize
= std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
op_.getOutput(0)->dims().cend(),
1,
std::multiplies<size_t>()); // NCHW...
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
return batchSize * ((poolSize - 1) * outputSize);
}
};
REGISTRAR(OperatorStats, "MaxPooling1D", MaxPoolingStats<MaxPooling_Op<1>>::create);
REGISTRAR(OperatorStats, "MaxPooling2D", MaxPoolingStats<MaxPooling_Op<2>>::create);
REGISTRAR(OperatorStats, "MaxPooling3D", MaxPoolingStats<MaxPooling_Op<3>>::create);
template <class OP>
class AvgPoolingStats : public OperatorStats {
public:
AvgPoolingStats(const Operator& op) : OperatorStats(op) {}
static std::unique_ptr<AvgPoolingStats<OP>> create(const Operator& op) {
return std::make_unique<AvgPoolingStats<OP>>(op);
}
size_t getNbArithmOps() const override {
const OP& op_ = dynamic_cast<const OP&>(mOp);
AIDGE_ASSERT(op_.dimsForwarded(), "Dims must be forwarded for static analysis");
const std::size_t poolSize
= std::accumulate(op_.kernelDims().cbegin(),
op_.kernelDims().cend(),
1,
std::multiplies<size_t>());
const std::size_t outputSize
= std::accumulate(op_.getOutput(0)->dims().cbegin() + 2,
op_.getOutput(0)->dims().cend(),
1,
std::multiplies<size_t>()); // NCHW...
const std::size_t batchSize = op_.getInput(0)->dims()[0]; // NCHW
// (poolSize - 1) additions + 1 division for each output
return batchSize * (poolSize * outputSize);
}
};
REGISTRAR(OperatorStats, "AvgPooling1D", AvgPoolingStats<AvgPooling_Op<1>>::create);
REGISTRAR(OperatorStats, "AvgPooling2D", AvgPoolingStats<AvgPooling_Op<2>>::create);
REGISTRAR(OperatorStats, "AvgPooling3D", AvgPoolingStats<AvgPooling_Op<3>>::create);
REGISTRAR(OperatorStats, "AvgPooling4D", AvgPoolingStats<AvgPooling_Op<4>>::create);
class FCStats : public OperatorStats { class FCStats : public OperatorStats {
public: public:
FCStats(const Operator& op) : OperatorStats(op) {} FCStats(const Operator& op) : OperatorStats(op) {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment