Commit e8b1c141 authored by Axel Farrugia

Merge branch 'benchmark_operator' into 'dev'

feat: Add missing operators for AIDGE model benchmarking

See merge request !36
parents 6a5441df 838ac45d
2 merge requests: !39 Update 0.2.1 -> 0.3.0, !36 feat: Add missing operators for AIDGE model benchmarking
Pipeline #71076 failed
Showing 843 additions and 23 deletions
@@ -14,6 +14,9 @@ dist*/
aidge_export_cpp/_version.py
wheelhouse/*

+# Temp test folders
+aidge_export_cpp/unit_tests/*_temp_test

# Mermaid
*.mmd
......
@@ -2,16 +2,18 @@
#define __AIDGE_EXPORT_CPP_KERNELS_BATCHNORM__

#include "network/typedefs.hpp"
-#include "kernels/rescaling.hpp"
+#include "kernels/activation.hpp"
#include <math.h>

// WARNING: this kernel only works for 32-bits floating point values

-template<int NB_OUTPUTS,
+template<int NB_BATCHES, int NB_OUTPUTS,
         int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
         ActivationFunction_T ACTIVATION,
         typename Input_T, typename Output_T,
-        typename Param_T>
+        typename Param_T,
+        typename Rescaling_T>
__attribute__((always_inline)) inline
void batchnorm_forward (
    const Input_T* __restrict inputs,
@@ -20,18 +22,22 @@ void batchnorm_forward (
    const Param_T* __restrict variances,
    const Param_T* __restrict means,
    const Param_T* __restrict scales,
-    const double epsilon)
+    const double epsilon,
+    const Rescaling_T& __restrict rescaling)
{
-    for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
-        const Output_T var = sqrt(variances[output] + epsilon);
+    for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) {
+        for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
+            // If the variance is 0, we need to avoid division by 0
+            Output_T var = sqrt(variances[output] > 0.0 ? variances[output] + epsilon : epsilon);
            for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
                for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
-                    const int outputOffset = OUTPUTS_HEIGHT * oy + ox;
-                    const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var;
+                    const int outputOffset = batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT + OUTPUTS_WIDTH * oy + ox;
+                    const Output_T normalized = (inputs[outputOffset] - means[output]) / var;
                    const Output_T sAs = scales[output] * normalized + biases[output];
-                    outputs[outputOffset + output] = sat<Output_T>(sAs, output, ACTIVATION, NoScaling);
+                    outputs[outputOffset] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
+                }
            }
        }
    }
......
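For readability, here is a minimal scalar sketch of the per-channel computation the updated kernel performs, including its zero-variance guard. It is illustrative only and not part of the export; the real kernel additionally applies the activation/rescaling step through activation_forward_value.

#include <cmath>

// Hedged sketch: y = scale[c] * (x - mean[c]) / sqrt(var[c] + eps) + bias[c]
inline float batchnorm_element(float x, float scale, float bias,
                               float mean, float variance, double epsilon) {
    // Same zero-variance guard as the kernel above.
    const float stddev = std::sqrt(variance > 0.0 ? variance + epsilon : epsilon);
    return scale * ((x - mean) / stddev) + bias;
}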
#ifndef __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
#define __AIDGE_EXPORT_CPP_KERNELS_CONCAT__

template<int AXIS_SIZE_POST,
         int AXIS_SIZE_PRE,
         unsigned int NB_INPUTS,
         typename T>
__attribute__((always_inline)) inline static
void concat_forward (
    const T* const * __restrict inputs,
    const unsigned int* __restrict sizes,
    T* __restrict output)
{
    unsigned int total_concat_axis_size = 0;
    for (unsigned int n = 0; n < NB_INPUTS; ++n)
        total_concat_axis_size += sizes[n];

    for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
        // Loop over post-axis (e.g., dims after axis 1)
        for (int j = 0; j < AXIS_SIZE_POST; ++j) {
            unsigned int axis_offset = 0;
            // Loop over each input tensor
            for (unsigned int n = 0; n < NB_INPUTS; ++n) {
                for (unsigned int k = 0; k < sizes[n]; ++k) {
                    const int input_idx = i * sizes[n] * AXIS_SIZE_POST + k * AXIS_SIZE_POST + j;
                    output[i * total_concat_axis_size * AXIS_SIZE_POST + (axis_offset + k) * AXIS_SIZE_POST + j] =
                        inputs[n][input_idx];
                }
                axis_offset += sizes[n]; // move along axis in output
            }
        }
    }
}

#endif // __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
\ No newline at end of file
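A minimal usage sketch of this kernel (hypothetical shapes and buffer names; the real call is generated from concat_forward.jinja below): concatenating two tensors of shape [2, 3, 4] and [2, 5, 4] along axis 1 into a [2, 8, 4] output.

float A[2 * 3 * 4];
float B[2 * 5 * 4];
float out[2 * 8 * 4];

const float* inputs[] = { A, B };
const unsigned int sizes[] = { 3, 5 };  // size of each input along the concat axis

// AXIS_SIZE_POST = 4 (product of dims after the axis), AXIS_SIZE_PRE = 2 (product of dims before it)
concat_forward<4, 2, 2, float>(inputs, sizes, out);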
#ifndef __AIDGE_EXPORT_CPP_KERNELS_PAD2D__
#define __AIDGE_EXPORT_CPP_KERNELS_PAD2D__

#include "network/typedefs.hpp"
#include "network/utils.hpp"

// TODO: support border types other than Constant (Reflect, Wrap, ...)
template<int NB_BATCHES, int NB_CHANNELS,
         int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
         int NB_OUTPUTS,
         int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
         int PADDING_TOP,
         int PADDING_LEFT,
         int PADDING_BOTTOM,
         int PADDING_RIGHT,
         typename Input_T, typename Output_T>
__attribute__((always_inline)) inline
void pad_forward(
    double borderValue,
    const Input_T* __restrict inputs,
    Output_T* __restrict outputs
    )
{
    const unsigned int oySize = CHANNELS_HEIGHT + PADDING_TOP + PADDING_BOTTOM;
    const unsigned int oxSize = CHANNELS_WIDTH + PADDING_LEFT + PADDING_RIGHT;

    for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) {
        for (unsigned int ch = 0; ch < NB_CHANNELS; ++ch) {
            const unsigned int preIndex = batch * NB_CHANNELS * CHANNELS_HEIGHT * CHANNELS_WIDTH + ch * CHANNELS_HEIGHT * CHANNELS_WIDTH;
            for (unsigned int oy = 0; oy < oySize; ++oy) {
                for (unsigned int ox = 0; ox < oxSize; ++ox) {
                    const unsigned int outIndex = batch * NB_CHANNELS * oySize * oxSize + ch * oySize * oxSize + oy * oxSize + ox;

                    outputs[outIndex] = borderValue;
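                    // Note (editorial): inputX/inputY are unsigned, so the subtractions below wrap
                    // around for coordinates inside the padded border, and the "< CHANNELS_*" checks reject them.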
                    const unsigned int inputX = ox - PADDING_LEFT;
                    const unsigned int inputY = oy - PADDING_TOP;

                    if (inputY >= 0 and inputY < CHANNELS_HEIGHT and inputX >= 0 and inputX < CHANNELS_WIDTH)
                    {
                        outputs[outIndex] = inputs[preIndex + inputY * CHANNELS_WIDTH + inputX];
                    }
                }
            }
        }
    }
}

#endif // __AIDGE_EXPORT_CPP_KERNELS_PAD2D__
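A minimal usage sketch (hypothetical shapes; the real call is generated from pad_forward.jinja below): constant-pad a 1x1x3x3 input by one pixel on every side into a 1x1x5x5 output.

float in[1 * 1 * 3 * 3];
float out[1 * 1 * 5 * 5];

pad_forward<1 /*NB_BATCHES*/, 1 /*NB_CHANNELS*/,
            3 /*CHANNELS_HEIGHT*/, 3 /*CHANNELS_WIDTH*/,
            1 /*NB_OUTPUTS*/,
            5 /*OUTPUTS_HEIGHT*/, 5 /*OUTPUTS_WIDTH*/,
            1 /*PADDING_TOP*/, 1 /*PADDING_LEFT*/,
            1 /*PADDING_BOTTOM*/, 1 /*PADDING_RIGHT*/>
    (0.0 /*borderValue*/, in, out);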
@@ -7,7 +7,7 @@
#include <stdexcept>

template<int NB_CHANNELS,
         int CHANNELS_HEIGHT, int CHANNELS_WIDTH,
         int NB_OUTPUTS,
         int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
@@ -17,7 +17,7 @@ template<int NB_CHANNELS,
         Pooling_T POOLING_TYPE,
         ActivationFunction_T ACTIVATION,
         typename Input_T, typename Output_T>
__attribute__((always_inline)) inline
void pooling_forward(
    const Input_T* __restrict inputs,
    Output_T* __restrict outputs)
@@ -32,7 +32,7 @@ void pooling_forward(
            : max(PADDING_Y - (oy * STRIDE_Y), 0);
        const int syMax = (PADDING_Y == 0
                && OUTPUTS_HEIGHT == OUTPUTS_HEIGHT_NOPAD) ? POOL_HEIGHT
            : clamp(CHANNELS_HEIGHT + PADDING_Y - (oy * STRIDE_Y),
                    0, POOL_HEIGHT);
        const int iy = (oy * STRIDE_Y) - PADDING_Y;
@@ -45,7 +45,7 @@ void pooling_forward(
            const int sxMax = (PADDING_X == 0
                    && OUTPUTS_WIDTH == OUTPUTS_WIDTH_NOPAD)
                ? POOL_WIDTH
                : clamp(CHANNELS_WIDTH + PADDING_X - (ox * STRIDE_X),
                        0, POOL_WIDTH);
            const int ix = (ox * STRIDE_X) - PADDING_X;
@@ -86,7 +86,7 @@ void pooling_forward(
                    outputs[oOffset + output] = maxVal;
                }
                else if (POOLING_TYPE == Average) {
-                    int32_t sum = 0;
+                    Output_T sum = 0;
                    for (int sy = 0; sy < POOL_HEIGHT; ++sy) {
                        if ((PADDING_Y != 0
......
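The only functional change in this pooling hunk is the average-pooling accumulator type (int32_t becomes Output_T). A simplified, illustrative sketch of why this matters for the float32 export (not the exported kernel itself):

float average_pool_window(const float* window, int count) {
    float sum = 0;                // was int32_t, which would truncate each float sample on accumulation
    for (int i = 0; i < count; ++i)
        sum += window[i];
    return sum / count;
}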
#ifndef __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
#define __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__

#include "network/typedefs.hpp"
#include "network/utils.hpp"
#include "kernels/macs.hpp"

#include <type_traits>
#include <cmath>
#include <algorithm>

template<int AXIS_SIZE,
         int AXIS_SIZE_POST,
         int AXIS_SIZE_PRE,
         typename Input_T, typename Output_T>
__attribute__((always_inline)) inline
void softmax_forward (
    const Input_T* __restrict inputs,
    Output_T* __restrict outputs)
{
    // Iterate over the "pre-axis" and "post-axis" slices.
    // For each slice along the axis, compute the maximum value,
    // the sum of exponentials, and then write the normalized softmax outputs.
    for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
        for (int j = 0; j < AXIS_SIZE_POST; ++j) {
            // Compute the base index for this slice.
            const int baseIdx = i * AXIS_SIZE * AXIS_SIZE_POST + j;

            // Find the maximum value along the axis.
            Input_T maxVal = inputs[baseIdx];
            for (int k = 1; k < AXIS_SIZE; ++k) {
                const int idx = baseIdx + k * AXIS_SIZE_POST;
                maxVal = std::max(maxVal, inputs[idx]);
            }

            // Compute the sum of the exponentials along the axis.
            Input_T sumExp = 0;
            for (int k = 0; k < AXIS_SIZE; ++k) {
                const int idx = baseIdx + k * AXIS_SIZE_POST;
                outputs[idx] = std::exp(inputs[idx] - maxVal);
                sumExp += outputs[idx];
            }

            // Write the softmax values to the output.
            for (int k = 0; k < AXIS_SIZE; ++k) {
                const int idx = baseIdx + k * AXIS_SIZE_POST;
                outputs[idx] /= sumExp;
            }
        }
    }
}

#endif // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__
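A minimal usage sketch (hypothetical shapes; the real call is generated from softmax_forward.jinja below): softmax over axis 1 of a [2, 4, 3] tensor, i.e. AXIS_SIZE_PRE = 2, AXIS_SIZE = 4, AXIS_SIZE_POST = 3.

float in[2 * 4 * 3];
float out[2 * 4 * 3];

softmax_forward<4 /*AXIS_SIZE*/, 3 /*AXIS_SIZE_POST*/, 2 /*AXIS_SIZE_PRE*/>(in, out);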
@@ -73,10 +73,25 @@ class ProducerCPP(ExportNode):

# TODO : find a way to remove this dummy exportnode
@ExportLibCpp.register("Pad2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.any)))
-class Pad_ARMCortexM(ExportNodeCpp):
+class PadCPP(ExportNodeCpp):
    def __init__(self, node, mem_info):
-        raise NotImplementedError("Pad2D nodes is not implemented")
+        super().__init__(node, mem_info)
+        self.attributes["padding"] = node.get_operator().attr.begin_end_borders
+        self.attributes["border_type"] = node.get_operator().attr.border_type
+        self.attributes["border_value"] = node.get_operator().attr.border_value
+
+        assert self.attributes["border_type"] == aidge_core.pad_border_type.Constant, (
+            f"export Pad2d: border_type == {node.get_operator().attr.border_type} not implemented"
+        )
+
+        self.config_template = str(
+            ROOT / "templates" / "configuration" / "pad_config.jinja")
+        self.forward_template = str(
+            ROOT / "templates" / "kernel_forward" / "pad_forward.jinja")
+        self.include_list = []
+        self.kernels_to_copy = [
+            str(ROOT / "kernels" / "pad.hpp")
+        ]

@ExportLibCpp.register("ReLU", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class ReLUCPP(ExportNodeCpp):
@@ -237,6 +252,20 @@ class MaxPoolCPP(ExportNodeCpp):

        _setup_pooling(self)
@ExportLibCpp.register("AvgPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class AvgPoolCPP(ExportNodeCpp):
    def __init__(self, node, mem_info):
        super().__init__(node, mem_info)
        # No padding with AvgPooling
        # Use a padded pooling meta-operator to add a padding attribute
        self.attributes["padding"] = [0, 0]
        self.attributes["pool_type"] = "Average"
        self.attributes["activation"] = "Linear"
        self.attributes["rescaling"] = "NoScaling"

        _setup_pooling(self)
@ExportLibCpp.register_metaop("PaddedMaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class PaddedMaxPoolCPP(ExportNodeCpp):
    def __init__(self, node, mem_info):
@@ -302,4 +331,117 @@ class TransposeCPP(ExportNodeCpp):
        self.include_list = []
        self.kernels_to_copy = [
            str(ROOT / "kernels" / "transpose.hpp")
        ]
@ExportLibCpp.register("Softmax", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class SoftmaxCPP(ExportNodeCpp):
    def __init__(self, node, mem_info):
        super().__init__(node, mem_info)

        assert self.node.get_nb_inputs() == 1, (
            f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented"
        )

        tensor = self.operator.get_input(0)
        nbDims = len(tensor.dims())
        axis = node.get_operator().attr.axis if node.get_operator().attr.axis >= 0 else node.get_operator().attr.axis + nbDims

        assert axis < nbDims, (
            f"export softmax: attribute axis == {node.get_operator().attr.axis} should be less than {nbDims}"
        )

        postAxisElems = 1
        for i in range(axis + 1, nbDims):
            postAxisElems *= tensor.dims()[i]

        preAxisElems = 1
        for i in range(axis):
            preAxisElems *= tensor.dims()[i]

        self.attributes["axis_size"] = tensor.dims()[axis]
        self.attributes["axis_size_post"] = postAxisElems
        self.attributes["axis_size_pre"] = preAxisElems

        self.config_template = str(
            ROOT / "templates" / "configuration" / "softmax_config.jinja")
        self.forward_template = str(
            ROOT / "templates" / "kernel_forward" / "softmax_forward.jinja")
        self.include_list = []
        self.kernels_to_copy = [
            str(ROOT / "kernels" / "softmax.hpp"),
            str(ROOT / "kernels" / "macs.hpp"),
        ]
@ExportLibCpp.register("BatchNorm2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class BatchNorm2DCPP(ExportNodeCpp):
    def __init__(self, node, mem_info):
        super().__init__(node, mem_info)
        self.attributes["activation"] = "Linear"
        self.attributes["rescaling"] = "NoScaling"
        self.attributes["epsilon"] = node.get_operator().attr.epsilon

        self.config_template = str(
            ROOT / "templates" / "configuration" / "batchnorm_config.jinja")
        self.forward_template = str(
            ROOT / "templates" / "kernel_forward" / "batchnorm_forward.jinja")
        self.include_list = []
        self.kernels_to_copy = [
            str(ROOT / "kernels" / "batchnorm.hpp"),
            str(ROOT / "kernels" / "macs.hpp"),
            str(ROOT / "kernels" / "activation.hpp"),
            str(ROOT / "kernels" / "rescaling.hpp")
        ]
@ExportLibCpp.register("Concat", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class Concat(ExportNodeCpp):
    def __init__(self, node, mem_info):
        super().__init__(node, mem_info)

        assert self.node.get_nb_inputs() >= 1, (
            f"export concat: nb_inputs == {self.node.get_nb_inputs()} not implemented"
        )

        inputIndex = 0

        tensor = self.operator.get_input(0)
        for idx, _ in enumerate(self.node.inputs()):
            if self.operator.get_input(idx) is not None:
                tensor = self.operator.get_input(idx)
                nbDims = len(tensor.dims())
                axis = node.get_operator().attr.axis if node.get_operator().attr.axis >= 0 else node.get_operator().attr.axis + nbDims

                assert axis < nbDims, (
                    f"export concat: attribute axis == {axis} should be less than {nbDims}"
                )

                postAxisElems = 1
                for i in range(axis + 1, nbDims):
                    postAxisElems *= tensor.dims()[i]

                preAxisElems = 1
                for i in range(axis):
                    preAxisElems *= tensor.dims()[i]

                if (inputIndex == 0):
                    self.attributes["axis_size_post"] = postAxisElems
                    self.attributes["axis_size_pre"] = preAxisElems
                    self.attributes["axis_size"] = [None] * self.attributes["nb_in"]
                else:
                    assert self.attributes["axis_size_post"] == postAxisElems, (
                        f"export concat: axis_size_post {self.attributes['axis_size_post']} != {postAxisElems}"
                    )
                    assert self.attributes["axis_size_pre"] == preAxisElems, (
                        f"export concat: axis_size_pre {self.attributes['axis_size_pre']} != {preAxisElems}"
                    )

                self.attributes["axis_size"][idx] = tensor.dims()[axis]
            else:
                assert False, (
                    f"export concat: input {idx} is None, not implemented")

            inputIndex += 1

        self.config_template = str(ROOT / "templates" / "configuration" / "concat_config.jinja")
        self.forward_template = str(ROOT / "templates" / "kernel_forward" / "concat_forward.jinja")
        self.include_list = []
        self.kernels_to_copy = [
            str(ROOT / "kernels" / "concat.hpp"),
        ]
\ No newline at end of file
@@ -4,6 +4,7 @@
#define {{ in_name[inidx]|upper }}_NB_CHANNELS {{ in_chan[inidx] }}
#define {{ in_name[inidx]|upper }}_IN_HEIGHT {{ in_height[inidx] }}
#define {{ in_name[inidx]|upper }}_IN_WIDTH {{ in_width[inidx] }}
+#define {{ in_name[inidx]|upper }}_IN_BATCH {{ in_batch[inidx] }}
{% endfor %}

// OUTPUT CONF
@@ -11,4 +12,5 @@
#define {{ out_name[outidx]|upper }}_NB_OUTPUTS {{ out_chan[outidx] }}
#define {{ out_name[outidx]|upper }}_OUT_HEIGHT {{ out_height[outidx] }}
#define {{ out_name[outidx]|upper }}_OUT_WIDTH {{ out_width[outidx] }}
+#define {{ out_name[outidx]|upper }}_OUT_BATCH {{ out_batch[outidx] }}
{% endfor %}
{#- For name header -#}
#ifndef {{ name|upper }}_LAYER_H
#define {{ name|upper }}_LAYER_H
+#include "kernels/rescaling.hpp"

{# For layer configuration -#}
{% include "./_def_io.jinja" %}
{% include "./_meminfo.jinja" %}
#define {{ name|upper }}_ACTIVATION {{ activation }}
#define {{ name|upper }}_EPSILON {{ epsilon }}
+static const {{ rescaling }} {{ name|upper }}_RESCALING = {};
#endif /* {{ name|upper }}_LAYER_H */
{#- For name header -#}
#ifndef {{ name|upper }}_LAYER_H
#define {{ name|upper }}_LAYER_H
{% include "./_def_io.jinja" %}
{% include "./_meminfo.jinja" %}
// Attributes
#define {{ name|upper }}_NB_INPUTS {{ nb_in }}
#define {{ name|upper }}_AXIS {{ axis }}
{%- for i in range(nb_in) %}
#define {{ name|upper }}_INPUT_{{i}}_SIZE {{ axis_size[i] }}
{%- endfor %}
#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}
#endif /* {{ name|upper }}_LAYER_H */
{#- For name header -#}
#ifndef {{ name|upper }}_LAYER_H
#define {{ name|upper }}_LAYER_H
{# For layer configuration -#}
{% include "./_def_io.jinja" %}
{% include "./_meminfo.jinja" %}
#define {{ name|upper }}_PADDING_BOTTOM {{ padding[2] }}
#define {{ name|upper }}_PADDING_RIGHT {{ padding[3] }}
#define {{ name|upper }}_PADDING_TOP {{ padding[0] }}
#define {{ name|upper }}_PADDING_LEFT {{ padding[1] }}
#define {{ name|upper }}_BORDER_VALUE {{ border_value }}
#endif /* {{ name|upper }}_LAYER_H */
{#- For name header -#}
#ifndef {{ name|upper }}_LAYER_H
#define {{ name|upper }}_LAYER_H
{# For layer configuration -#}
{% include "./_def_io.jinja" %}
{% include "./_meminfo.jinja" %}
{#- Calculate sizes #}
{%- set weights_size = out_chan[0] * in_chan[0] * in_height[0] * in_width[0] %}
#define {{ name|upper }}_AXIS_SIZE {{ axis_size }}
#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}
#endif /* {{ name|upper }}_LAYER_H */
{% filter indent(width=4, first=False) %}
{% include "./_mem_offset.jinja" %}
-batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS,
+batchnorm_forward<{{ out_name[0]|upper }}_OUT_BATCH,
+                  {{ out_name[0]|upper }}_NB_OUTPUTS,
                  {{ out_name[0]|upper }}_OUT_HEIGHT,
                  {{ out_name[0]|upper }}_OUT_WIDTH,
                  {{name|upper}}_ACTIVATION>
-                 ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON);
+                 ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON, {{name|upper}}_RESCALING);
{% include "./_save_outputs.jinja" %}
{% endfilter %}
{% filter indent(width=4, first=False) %}
{% include "./_mem_offset.jinja" %}
const float* {{ name|upper }}_INPUTS[] = {
{%- for i in range(nb_in) -%}
{{ in_name[i] }}{{ ", " if not loop.last else "" }}
{%- endfor -%}
};
unsigned int {{ name|upper }}_SIZES[] = {
{%- for i in range(nb_in) -%}
{{ name|upper }}_INPUT_{{i}}_SIZE{{ ", " if not loop.last else "" }}
{%- endfor -%}
};
concat_forward<{{ name|upper }}_AXIS_SIZE_POST,
{{ name|upper }}_AXIS_SIZE_PRE,
{{ nb_in }},
float> (
{{ name|upper }}_INPUTS,
{{ name|upper }}_SIZES,
{{ out_name[0] }});
{% endfilter %}
{% filter indent(width=4, first=False) %}
{% include "./_mem_offset.jinja" %}
pad_forward<{{ in_name[0]|upper }}_IN_BATCH,
{{ in_name[0]|upper }}_NB_CHANNELS,
{{ in_name[0]|upper }}_IN_HEIGHT,
{{ in_name[0]|upper }}_IN_WIDTH,
{{ out_name[0]|upper }}_NB_OUTPUTS,
{{ out_name[0]|upper }}_OUT_HEIGHT,
{{ out_name[0]|upper }}_OUT_WIDTH,
{{name|upper}}_PADDING_TOP,
{{name|upper}}_PADDING_LEFT,
{{name|upper}}_PADDING_BOTTOM,
{{name|upper}}_PADDING_RIGHT>
({{name|upper}}_BORDER_VALUE, {{in_name[0]}}, {{out_name[0]}});
{% include "./_save_outputs.jinja" %}
{% endfilter %}
{% filter indent(width=4, first=False) %}
{% include "./_mem_offset.jinja" %}
softmax_forward<{{ name|upper }}_AXIS_SIZE,
{{ name|upper }}_AXIS_SIZE_POST,
{{ name|upper }}_AXIS_SIZE_PRE>
({{in_name[0]}}, {{out_name[0]}});
{% include "./_save_outputs.jinja" %}
{% endfilter %}
@@ -3,9 +3,12 @@ import aidge_core
import aidge_backend_cpu
import aidge_export_cpp
import numpy as np
+import operator
+from functools import reduce

import subprocess
import re
+import shutil

from aidge_core.utils import run_command


def initFiller(model):
@@ -32,6 +35,32 @@ def initFiller(model):
        else:
            pass
def _np_init(shape, dtype=np.float32):
    """
    Generates a NumPy array with the given shape, filled with random values between -1 and 1
    with a step of 0.1.

    :param shape: Tuple of dimensions for the array
    :param dtype: Data type of the output array (default: np.float32)
    :return: A NumPy array with the given shape and dtype
    """
    total_elements = reduce(operator.mul, shape, 1)
    data = (np.random.randint(0, 21, size=total_elements) - 10) / 10.0
    return data.reshape(shape).astype(dtype)


def _np_init_ones(shape, default_value=0.01, dtype=np.float32):
    """
    Generates a NumPy array with the given shape, filled with the constant default_value.

    :param shape: Tuple of dimensions for the array
    :param default_value: Constant used to fill the array (default: 0.01)
    :param dtype: Data type of the output array (default: np.float32)
    :return: A NumPy array with the given shape and dtype
    """
    total_elements = reduce(operator.mul, shape, 1)
    data = np.ones(total_elements) * default_value
    return data.reshape(shape).astype(dtype)
class test_operator_export(unittest.TestCase):
@@ -43,7 +72,7 @@ class test_operator_export(unittest.TestCase):
    def tearDown(self):
        pass

-    def unit_test_export(self, graph_view, op_name, in_dims):
+    def unit_test_export(self, graph_view, op_name, in_dims, random_inputs=True, random_weights=True, default_value=0.01):
        """
        TODO:
        * Handle multiple dataformat
@@ -56,14 +85,34 @@ class test_operator_export(unittest.TestCase):
        4- Retrieve standard output and use regex to know if the results are the same
        """
        graph_view.compile("cpu", aidge_core.dtype.float32, dims=in_dims)

+        for node in graph_view.get_nodes():
+            if node.type() == "Producer":
+                prod_op = node.get_operator()
+                value = prod_op.get_output(0)
+                if (random_weights):
+                    tensor = aidge_core.Tensor(_np_init(value.dims()))
+                    node.get_operator().set_output(0, tensor)
+                else:
+                    aidge_core.constant_filler(value, default_value)
+
        scheduler = aidge_core.SequentialScheduler(graph_view)

-        in_tensor = [aidge_core.Tensor(np.random.random(in_dim).astype(np.float32)) for in_dim in in_dims]
+        if (random_inputs):
+            in_tensor = [aidge_core.Tensor(_np_init(in_dim)) for in_dim in in_dims]
+        else:
+            in_tensor = [aidge_core.Tensor(_np_init_ones(in_dim, default_value)) for in_dim in in_dims]
+
        scheduler.forward(data=in_tensor)

        # Note the convention ``<op_name>_test`` is useful for gitignore to avoid pushing generated export by accident.
        export_folder = op_name + "_test"
+        shutil.rmtree(export_folder, ignore_errors=True)

        # Export the model in C++ standalone
        aidge_core.export_utils.scheduler_export(
            scheduler,
@@ -112,6 +161,46 @@ class test_operator_export(unittest.TestCase):
        self.unit_test_export(model, "FC_flat", [[1, 6, 1, 1]])
def test_export_softmax(self):
print("Softmax")
model = aidge_core.sequential([
aidge_core.Softmax(axis=1, name="sf0")
])
self.unit_test_export(model, "Softmax", [[1, 10]])
def test_export_softmax_batch(self):
print("SoftmaxBatch")
model = aidge_core.sequential([
aidge_core.Softmax(axis=1, name="sf0")
])
self.unit_test_export(model, "SoftmaxBatch", [[3, 10]])
def test_export_softmax_axis_2(self):
print("SoftmaxAxis2")
model = aidge_core.sequential([
aidge_core.Softmax(axis=2, name="sf0")
])
self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]])
def test_export_softmax_axis_negative(self):
print("SoftmaxAxisNegative")
model = aidge_core.sequential([
aidge_core.Softmax(axis=-3, name="sf0")
])
self.unit_test_export(model, "SoftmaxAxisNegative", [[1, 10, 3, 7]])
def test_export_softmax_axis_0(self):
print("SoftmaxAxis0")
model = aidge_core.sequential([
aidge_core.Softmax(axis=0, name="sf0")
])
self.unit_test_export(model, "SoftmaxAxis0", [[10]])
    @unittest.skip("Currently this test is failing")
    def test_export_FC_image_in(self):
        """Test exporting a FC operator with a HWC input.
@@ -122,6 +211,347 @@ class test_operator_export(unittest.TestCase):
        initFiller(model)
        self.unit_test_export(model, "FC_img", [[1, 3, 2, 2]])
def test_export_relu(self):
print("ReLU")
model = aidge_core.sequential([
aidge_core.ReLU(name="relu0")
])
self.unit_test_export(model, "ReLU", [[1, 10]])
def test_export_add(self):
print("Add")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 5], name="producer"),
aidge_core.Add(name="add")
])
self.unit_test_export(model, "Add", [[1, 5, 5]])
def test_export_add_larger(self):
print("AddLarger")
model = aidge_core.sequential([
aidge_core.Producer([1, 7, 5], name="producer"),
aidge_core.Add(name="add")
])
self.unit_test_export(model, "Add", [[1, 7, 5]])
def test_export_add_higher(self):
print("AddHigher")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 7], name="producer"),
aidge_core.Add(name="add")
])
self.unit_test_export(model, "Add", [[1, 5, 7]])
# "Broadcast not supported yet in export operator"
@unittest.expectedFailure
def test_export_add_simple_broadcast(self):
print("AddSimpleBroadcast")
model = aidge_core.sequential([
aidge_core.Producer([1, 1, 5], name="producer"),
aidge_core.Add(name="add")
])
self.unit_test_export(model, "AddSimpleBroadcast", [[1, 7, 5]])
# "Broadcast not supported yet in export operator"
@unittest.expectedFailure
def test_export_add_double_broadcast(self):
print("AddDoubleBroadcast")
model = aidge_core.sequential([
aidge_core.Producer([1, 1, 7], name="producer"),
aidge_core.Add(name="add")
])
self.unit_test_export(model, "AddDoubleBroadcast", [[1, 5, 1]])
def test_export_sub(self):
print("Sub")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 5], name="producer"),
aidge_core.Sub(name="sub")
])
self.unit_test_export(model, "Sub", [[1, 5, 5]])
def test_export_sub_larger(self):
print("SubLarger")
model = aidge_core.sequential([
aidge_core.Producer([1, 7, 5], name="producer"),
aidge_core.Sub(name="sub")
])
self.unit_test_export(model, "Sub", [[1, 7, 5]])
def test_export_sub_higher(self):
print("SubHigher")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 7], name="producer"),
aidge_core.Sub(name="sub")
])
self.unit_test_export(model, "Sub", [[1, 5, 7]])
# "Broadcast not supported yet in export operator"
@unittest.expectedFailure
def test_export_sub_simple_broadcast(self):
print("SubSimpleBroadcast")
model = aidge_core.sequential([
aidge_core.Producer([1, 1, 5], name="producer"),
aidge_core.Sub(name="sub")
])
self.unit_test_export(model, "SubSimpleBroadcast", [[1, 7, 5]])
# "Broadcast not supported yet in export operator"
@unittest.expectedFailure
def test_export_sub_double_broadcast(self):
print("SubDoubleBroadcast")
model = aidge_core.sequential([
aidge_core.Producer([1, 1, 7], name="producer"),
aidge_core.Sub(name="sub")
])
self.unit_test_export(model, "SubDoubleBroadcast", [[1, 5, 1]])
def test_export_mul(self):
print("Mul")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 5], name="producer"),
aidge_core.Mul(name="mul")
])
self.unit_test_export(model, "Mul", [[1, 5, 5]])
def test_export_mul_larger(self):
print("MulLarger")
model = aidge_core.sequential([
aidge_core.Producer([1, 7, 5], name="producer"),
aidge_core.Mul(name="mul")
])
self.unit_test_export(model, "Mul", [[1, 7, 5]])
def test_export_mul_higher(self):
print("MulHigher")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 7], name="producer"),
aidge_core.Mul(name="mul")
])
self.unit_test_export(model, "Mul", [[1, 5, 7]])
# "Broadcast not supported yet in export operator"
@unittest.expectedFailure
def test_export_mul_simple_broadcast(self):
print("MulSimpleBroadcast")
model = aidge_core.sequential([
aidge_core.Producer([1, 1, 5], name="producer"),
aidge_core.Mul(name="mul")
])
self.unit_test_export(model, "MulSimpleBroadcast", [[1, 7, 5]])
# "Broadcast not supported yet in export operator"
@unittest.expectedFailure
def test_export_mul_double_broadcast(self):
print("MulDoubleBroadcast")
model = aidge_core.sequential([
aidge_core.Producer([1, 1, 7], name="producer"),
aidge_core.Mul(name="mul")
])
self.unit_test_export(model, "MulDoubleBroadcast", [[1, 5, 1]])
def test_export_mul_batch(self):
print("MulBatch")
model = aidge_core.sequential([
aidge_core.Producer([3, 5, 7], name="producer"),
aidge_core.Mul(name="mul")
])
self.unit_test_export(model, "MulBatch", [[3, 5, 7]])
def test_export_concat(self):
print("Concat")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 7], name="producer"),
aidge_core.Concat(nb_inputs=2, axis=1, name="concat")
])
self.unit_test_export(model, "Concat", [[1, 5, 7]])
def test_export_concat_axis_2(self):
print("ConcatAxis2")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 7], name="producer"),
aidge_core.Concat(nb_inputs=2, axis=2, name="concat")
])
self.unit_test_export(model, "ConcatAxis2", [[1, 5, 7]])
def test_export_concat_axis_negative(self):
print("ConcatAxisNegative")
model = aidge_core.sequential([
aidge_core.Producer([1, 5, 7], name="producer"),
aidge_core.Concat(nb_inputs=2, axis=-2, name="concat")
])
self.unit_test_export(model, "ConcatAxisNegative", [[1, 5, 7]])
def test_export_conv2D(self):
print("Conv2D")
model = aidge_core.sequential([
aidge_core.Conv2D(in_channels=3, out_channels=3, kernel_dims=(3, 3), name="conv")
])
self.unit_test_export(model, "Conv2D", [[1, 3, 12, 12]], False, False)
def test_export_max_pooling(self):
print("MaxPooling2D")
model = aidge_core.sequential([
aidge_core.MaxPooling2D(kernel_dims=(3, 3), name="max_pool")
])
self.unit_test_export(model, "MaxPooling2D", [[1, 2, 12, 12]], False, False)
def test_export_avg_pooling(self):
print("AvgPooling2D")
model = aidge_core.sequential([
aidge_core.AvgPooling2D(kernel_dims=(3, 3), name="avg_pool")
])
self.unit_test_export(model, "AvgPooling2D", [[1, 2, 12, 12]], False, False)
def test_export_pad2D(self):
print("Pad2D")
model = aidge_core.sequential([
aidge_core.Pad2D((1, 1, 1, 1), name="pad2d")
])
self.unit_test_export(model, "Pad2D", [[1, 1, 11, 11]])
def test_export_pad2D_larger(self):
print("Pad2DLarger")
model = aidge_core.sequential([
aidge_core.Pad2D((1, 3, 1, 3), name="pad2d")
])
self.unit_test_export(model, "Pad2DLarger", [[1, 1, 7, 11]])
def test_export_pad2D_higher(self):
print("Pad2DHigher")
model = aidge_core.sequential([
aidge_core.Pad2D((3, 1, 3, 1), name="pad2d")
])
self.unit_test_export(model, "Pad2DHigher", [[1, 1, 11, 7]])
def test_export_pad2D_mismatch(self):
print("Pad2DMismatch")
model = aidge_core.sequential([
aidge_core.Pad2D((1, 3, 5, 7), name="pad2d")
])
self.unit_test_export(model, "Pad2DMismatch", [[3, 5, 11, 7]])
def test_export_pad2D_denser(self):
print("Pad2DDenser")
model = aidge_core.sequential([
aidge_core.Pad2D((3, 3, 3, 3), name="pad2d")
])
self.unit_test_export(model, "Pad2DDenser", [[1, 5, 7, 11]])
def test_export_pad2D_with_bigger_batch_size(self):
print("Pad2DBiggerBatchSize")
model = aidge_core.sequential([
aidge_core.Pad2D((1, 1, 1, 1), name="pad2d")
])
self.unit_test_export(model, "Pad2DBiggerBatchSize", [[3, 5, 7, 11]])
@unittest.expectedFailure
def test_export_pad2D_not_constant(self):
print("Pad2DNotConstant")
model = aidge_core.sequential([
aidge_core.Pad2D((3, 3, 3, 3), border_type=aidge_core.pad_border_type.Wrap, name="pad2d")
])
self.unit_test_export(model, "Pad2DNotConstant", [[1, 5, 7, 11]])
def test_export_batchnorm2D(self):
print("BatchNormalization2D")
model = aidge_core.sequential([
aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
])
self.unit_test_export(model, "BatchNorm2D", [[1, 1, 5, 5]], False, False)
def test_export_batchnorm2D_Larger(self):
print("BatchNormalization2DLarger")
model = aidge_core.sequential([
aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
])
self.unit_test_export(model, "BatchNorm2DLarger", [[1, 1, 5, 7]], False, False)
def test_export_batchnorm2D_Higher(self):
print("BatchNormalization2DHigher")
model = aidge_core.sequential([
aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
])
self.unit_test_export(model, "BatchNorm2DHigher", [[1, 1, 7, 5]], False, False)
def test_export_batchnorm2D_Denser(self):
print("BatchNormalization2DDenser")
model = aidge_core.sequential([
aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
])
self.unit_test_export(model, "BatchNorm2DDenser", [[1, 3, 5, 7]], False, False)
def test_export_batchnorm2D_with_bigger_batch_size(self):
print("BatchNormalization2DBiggerBatchSize")
model = aidge_core.sequential([
aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn")
])
self.unit_test_export(model, "BatchNormalization2DBiggerBatchSize", [[4, 3, 5, 7]], False, False)
def test_export_cpp(self):
print("Export test to do")
    def test_export_Conv(self):
        model = aidge_core.sequential([
            aidge_core.Conv2D(1, 1, [3, 3], name="InputNode")
......