diff --git a/.gitignore b/.gitignore index 67ffbefbdc41ea1abebd64602649fb129f2faf07..93bcfd30700409d495c7a6d4eb19c12636afbda8 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,9 @@ dist*/ aidge_export_cpp/_version.py wheelhouse/* +# Temp test folders +aidge_export_cpp/unit_tests/*_temp_test + # Mermaid *.mmd diff --git a/aidge_export_cpp/kernels/batchnorm.hpp b/aidge_export_cpp/kernels/batchnorm.hpp index 740ea21e6f66ba338985db4f724a5d57377e1f81..f05a047511e12f895ef88be0e402b89e5197432b 100644 --- a/aidge_export_cpp/kernels/batchnorm.hpp +++ b/aidge_export_cpp/kernels/batchnorm.hpp @@ -2,16 +2,18 @@ #define __AIDGE_EXPORT_CPP_KERNELS_BATCHNORM__ #include "network/typedefs.hpp" -#include "kernels/rescaling.hpp" +#include "kernels/activation.hpp" + #include <math.h> // WARNING: this kernel only works for 32-bits floating point values -template<int NB_OUTPUTS, +template<int NB_BATCHES, int NB_OUTPUTS, int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH, ActivationFunction_T ACTIVATION, typename Input_T, typename Output_T, - typename Param_T> + typename Param_T, + typename Rescaling_T> __attribute__((always_inline)) inline void batchnorm_forward ( const Input_T* __restrict inputs, @@ -20,18 +22,22 @@ void batchnorm_forward ( const Param_T* __restrict variances, const Param_T* __restrict means, const Param_T* __restrict scales, - const double epsilon) + const double epsilon, + const Rescaling_T& __restrict rescaling) { - for (unsigned int output = 0; output < NB_OUTPUTS; ++output) { - const Output_T var = sqrt(variances[output] + epsilon); + for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) { + for (unsigned int output = 0; output < NB_OUTPUTS; ++output) { + // If the variance is 0, we need to avoid division by 0 + Output_T var = sqrt(variances[output] > 0.0 ? 
variances[output] + epsilon : epsilon); - for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) { - for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) { - const int outputOffset = OUTPUTS_HEIGHT * oy + ox; + for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) { + for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) { + const int outputOffset = batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT + OUTPUTS_WIDTH * oy + ox; - const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var; - const Output_T sAs = scales[output] * normalized + biases[output]; - outputs[outputOffset + output] = sat<Output_T>(sAs, output, ACTIVATION, NoScaling); + const Output_T normalized = (inputs[outputOffset] - means[output]) / var; + const Output_T sAs = scales[output] * normalized + biases[output]; + outputs[outputOffset] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling); + } } } } diff --git a/aidge_export_cpp/kernels/concat.hpp b/aidge_export_cpp/kernels/concat.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dde8c4fc3a9ce9eea5d4ae4cfad35c078f60450d --- /dev/null +++ b/aidge_export_cpp/kernels/concat.hpp @@ -0,0 +1,39 @@ +#ifndef __AIDGE_EXPORT_CPP_KERNELS_CONCAT__ +#define __AIDGE_EXPORT_CPP_KERNELS_CONCAT__ + +template<int AXIS_SIZE_POST, + int AXIS_SIZE_PRE, + unsigned int NB_INPUTS, + typename T> +__attribute__((always_inline)) inline static +void concat_forward ( + const T* const * __restrict inputs, + const unsigned int* __restrict sizes, + T* __restrict output) +{ + unsigned int total_concat_axis_size = 0; + for (unsigned int n = 0; n < NB_INPUTS; ++n) + total_concat_axis_size += sizes[n]; + + for (int i = 0; i < AXIS_SIZE_PRE; ++i) { + // Loop over post-axis (e.g., dims after axis 1) + for (int j = 0; j < AXIS_SIZE_POST; ++j) { + unsigned int axis_offset = 0; + + // Loop over each input tensor + for (unsigned int n = 0; n < NB_INPUTS; ++n) { + for (unsigned int k = 0; k < sizes[n]; ++k) { + const int input_idx = i * sizes[n] * AXIS_SIZE_POST + k * AXIS_SIZE_POST + j; + + output[i * total_concat_axis_size * AXIS_SIZE_POST + (axis_offset + k) * AXIS_SIZE_POST + j] = + inputs[n][input_idx]; + } + + axis_offset += sizes[n]; // move along axis in output + } + } + } + +} + +#endif // __AIDGE_EXPORT_CPP_KERNELS_CONCAT__ \ No newline at end of file diff --git a/aidge_export_cpp/kernels/pad.hpp b/aidge_export_cpp/kernels/pad.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4e83257c1152b1963dd4b0eefc912216a729de7d --- /dev/null +++ b/aidge_export_cpp/kernels/pad.hpp @@ -0,0 +1,51 @@ +#ifndef __AIDGE_EXPORT_CPP_KERNELS_PAD2D__ +#define __AIDGE_EXPORT_CPP_KERNELS_PAD2D__ + +#include "network/typedefs.hpp" +#include "network/utils.hpp" + +// Todo add border value and border type (Reflect, Constant, Wrap...) 
and add the two missing pad value (bottom and right) + +template<int NB_BATCHES, int NB_CHANNELS, + int CHANNELS_HEIGHT, int CHANNELS_WIDTH, + int NB_OUTPUTS, + int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH, + int PADDING_TOP, + int PADDING_LEFT, + int PADDING_BOTTOM, + int PADDING_RIGHT, + typename Input_T, typename Output_T> +__attribute__((always_inline)) inline +void pad_forward( + double borderValue, + const Input_T* __restrict inputs, + Output_T* __restrict outputs + ) +{ + const unsigned int oySize = CHANNELS_HEIGHT + PADDING_TOP + PADDING_BOTTOM; + const unsigned int oxSize = CHANNELS_WIDTH + PADDING_LEFT + PADDING_RIGHT; + + for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) { + for (unsigned int ch = 0; ch < NB_CHANNELS; ++ch) { + const unsigned int preIndex = batch * NB_CHANNELS * CHANNELS_HEIGHT * CHANNELS_WIDTH + ch * CHANNELS_HEIGHT * CHANNELS_WIDTH; + + for (unsigned int oy = 0; oy < oySize; ++oy) { + for (unsigned int ox = 0; ox < oxSize; ++ox) { + const unsigned int outIndex = batch * NB_CHANNELS * oySize * oxSize + ch * oySize * oxSize + oy * oxSize + ox; + + outputs[outIndex] = borderValue; + + const unsigned int inputX = ox - PADDING_LEFT; + const unsigned int inputY = oy - PADDING_TOP; + + if (inputY >= 0 and inputY < CHANNELS_HEIGHT and inputX >= 0 and inputX < CHANNELS_WIDTH) + { + outputs[outIndex] = inputs[preIndex + inputY * CHANNELS_WIDTH + inputX]; + } + } + } + } + } +} + +#endif // __AIDGE_EXPORT_CPP_KERNELS_PAD2D__ diff --git a/aidge_export_cpp/kernels/pooling.hpp b/aidge_export_cpp/kernels/pooling.hpp index 478b6a58aed45e2bce0ed1683ad113f9c7a8bffb..a86fd4196a9f6e19f45dbdc4f1035c1e94e7d285 100644 --- a/aidge_export_cpp/kernels/pooling.hpp +++ b/aidge_export_cpp/kernels/pooling.hpp @@ -7,7 +7,7 @@ #include <stdexcept> -template<int NB_CHANNELS, +template<int NB_CHANNELS, int CHANNELS_HEIGHT, int CHANNELS_WIDTH, int NB_OUTPUTS, int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH, @@ -17,7 +17,7 @@ template<int NB_CHANNELS, Pooling_T POOLING_TYPE, ActivationFunction_T ACTIVATION, typename Input_T, typename Output_T> -__attribute__((always_inline)) inline +__attribute__((always_inline)) inline void pooling_forward( const Input_T* __restrict inputs, Output_T* __restrict outputs) @@ -32,7 +32,7 @@ void pooling_forward( : max(PADDING_Y - (oy * STRIDE_Y), 0); const int syMax = (PADDING_Y == 0 && OUTPUTS_HEIGHT == OUTPUTS_HEIGHT_NOPAD) ? POOL_HEIGHT - : clamp(CHANNELS_HEIGHT + PADDING_Y - (oy * STRIDE_Y), + : clamp(CHANNELS_HEIGHT + PADDING_Y - (oy * STRIDE_Y), 0, POOL_HEIGHT); const int iy = (oy * STRIDE_Y) - PADDING_Y; @@ -45,7 +45,7 @@ void pooling_forward( const int sxMax = (PADDING_X == 0 && OUTPUTS_WIDTH == OUTPUTS_WIDTH_NOPAD) ? 
POOL_WIDTH - : clamp(CHANNELS_WIDTH + PADDING_X - (ox * STRIDE_X), + : clamp(CHANNELS_WIDTH + PADDING_X - (ox * STRIDE_X), 0, POOL_WIDTH); const int ix = (ox * STRIDE_X) - PADDING_X; @@ -86,7 +86,7 @@ void pooling_forward( outputs[oOffset + output] = maxVal; } else if (POOLING_TYPE == Average) { - int32_t sum = 0; + Output_T sum = 0; for (int sy = 0; sy < POOL_HEIGHT; ++sy) { if ((PADDING_Y != 0 diff --git a/aidge_export_cpp/kernels/softmax.hpp b/aidge_export_cpp/kernels/softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f5472cf6d807bc2f547e58616943f6e72dccd80e --- /dev/null +++ b/aidge_export_cpp/kernels/softmax.hpp @@ -0,0 +1,53 @@ +#ifndef __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__ +#define __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__ + +#include "network/typedefs.hpp" +#include "network/utils.hpp" +#include "kernels/macs.hpp" + +#include <type_traits> +#include <cmath> +#include <algorithm> + +template<int AXIS_SIZE, + int AXIS_SIZE_POST, + int AXIS_SIZE_PRE, + typename Input_T, typename Output_T> +__attribute__((always_inline)) inline +void softmax_forward ( + const Input_T* __restrict inputs, + Output_T* __restrict outputs) +{ + // Iterate over the "pre-axis" and "post-axis" slices. + // For each slice along the axis, compute the maximum value, + // the sum of exponentials, and then write the normalized softmax outputs. + for (int i = 0; i < AXIS_SIZE_PRE; ++i) { + for (int j = 0; j < AXIS_SIZE_POST; ++j) { + // Compute the base index for this slice. + const int baseIdx = i * AXIS_SIZE * AXIS_SIZE_POST + j; + + // Find the maximum value along the axis. + Input_T maxVal = inputs[baseIdx]; + for (int k = 1; k < AXIS_SIZE; ++k) { + const int idx = baseIdx + k * AXIS_SIZE_POST; + maxVal = std::max(maxVal, inputs[idx]); + } + + // Compute the sum of the exponentials along the axis. + Input_T sumExp = 0; + for (int k = 0; k < AXIS_SIZE; ++k) { + const int idx = baseIdx + k * AXIS_SIZE_POST; + outputs[idx] = std::exp(inputs[idx] - maxVal); + sumExp += outputs[idx]; + } + + // Write the softmax values to the output. 
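+ // The loop below divides each exponential by the shared sumExp, so the values along the axis sum to 1. + // Worked example with AXIS_SIZE = 3 and inputs {1, 2, 3}: maxVal = 3, exponentials {exp(-2), exp(-1), exp(0)} ~= {0.14, 0.37, 1.00}, normalized outputs ~= {0.09, 0.24, 0.67}.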
+ for (int k = 0; k < AXIS_SIZE; ++k) { + const int idx = baseIdx + k * AXIS_SIZE_POST; + outputs[idx] /= sumExp; + } + } + } +} + +#endif // __AIDGE_EXPORT_CPP_KERNELS_SOFTMAX__ diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index 346928f4a84c403df2172311cede8b99fd06eebe..26ca62155401707573d9625ad91a9b63cb1b4d2b 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -73,10 +73,25 @@ class ProducerCPP(ExportNode): # TODO : find a way to remove this dummy exportnode @ExportLibCpp.register("Pad2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.any))) -class Pad_ARMCortexM(ExportNodeCpp): +class PadCPP(ExportNodeCpp): def __init__(self, node, mem_info): - raise NotImplementedError("Pad2D nodes is not implemented") + super().__init__(node, mem_info) + self.attributes["padding"] = node.get_operator().attr.begin_end_borders + self.attributes["border_type"] = node.get_operator().attr.border_type + self.attributes["border_value"] = node.get_operator().attr.border_value + + assert self.attributes["border_type"] == aidge_core.pad_border_type.Constant, ( + f"export Pad2D: border_type == {node.get_operator().attr.border_type} not implemented" + ) + self.config_template = str( + ROOT / "templates" / "configuration" / "pad_config.jinja") + self.forward_template = str( + ROOT / "templates" / "kernel_forward" / "pad_forward.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "pad.hpp") + ] @ExportLibCpp.register("ReLU", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class ReLUCPP(ExportNodeCpp): @@ -237,6 +252,20 @@ class MaxPoolCPP(ExportNodeCpp): _setup_pooling(self) +@ExportLibCpp.register("AvgPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) +class AvgPoolCPP(ExportNodeCpp): + def __init__(self, node, mem_info): + super().__init__(node, mem_info) + + # No padding with AvgPooling + # Use PaddedAvgPooling to add the padding attribute + self.attributes["padding"] = [0, 0] + self.attributes["pool_type"] = "Average" + self.attributes["activation"] = "Linear" + self.attributes["rescaling"] = "NoScaling" + + _setup_pooling(self) + @ExportLibCpp.register_metaop("PaddedMaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class PaddedMaxPoolCPP(ExportNodeCpp): def __init__(self, node, mem_info): @@ -302,4 +331,117 @@ class TransposeCPP(ExportNodeCpp): self.include_list = [] self.kernels_to_copy = [ str(ROOT / "kernels" / "transpose.hpp") + ] + +@ExportLibCpp.register("Softmax", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) +class SoftmaxCPP(ExportNodeCpp): + def __init__(self, node, mem_info): + super().__init__(node, mem_info) + assert self.node.get_nb_inputs() == 1, ( + f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented" + ) + + tensor = self.operator.get_input(0) + nbDims = len(tensor.dims()) + axis = node.get_operator().attr.axis if node.get_operator().attr.axis >= 0 else node.get_operator().attr.axis + nbDims + + assert axis < nbDims, ( + f"export softmax: attribute axis == {node.get_operator().attr.axis} should be less than {nbDims}" + ) + + postAxisElems = 1 + for i in range(axis + 1, nbDims): + postAxisElems *= tensor.dims()[i] + + preAxisElems = 1 + for i in range(axis): + preAxisElems *= tensor.dims()[i] + + self.attributes["axis_size"] = tensor.dims()[axis] + self.attributes["axis_size_post"] = postAxisElems + self.attributes["axis_size_pre"] = preAxisElems + + self.config_template = str( +
ROOT / "templates" / "configuration" / "softmax_config.jinja") + self.forward_template = str( + ROOT / "templates" / "kernel_forward" / "softmax_forward.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "softmax.hpp"), + str(ROOT / "kernels" / "macs.hpp"), + ] + +@ExportLibCpp.register("BatchNorm2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) +class BatchNorm2DCPP(ExportNodeCpp): + def __init__(self, node, mem_info): + super().__init__(node, mem_info) + self.attributes["activation"] = "Linear" + self.attributes["rescaling"] = "NoScaling" + self.attributes["epsilon"] = node.get_operator().attr.epsilon + self.config_template = str( + ROOT / "templates" / "configuration" / "batchnorm_config.jinja") + self.forward_template = str( + ROOT / "templates" / "kernel_forward" / "batchnorm_forward.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "batchnorm.hpp"), + str(ROOT / "kernels" / "macs.hpp"), + str(ROOT / "kernels" / "activation.hpp"), + str(ROOT / "kernels" / "rescaling.hpp") + ] + +@ExportLibCpp.register("Concat", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) +class Concat(ExportNodeCpp): + def __init__(self, node, mem_info): + super().__init__(node, mem_info) + assert self.node.get_nb_inputs() >= 1, ( + f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented" + ) + + inputIndex = 0 + + tensor = self.operator.get_input(0) + for idx, _ in enumerate(self.node.inputs()): + if self.operator.get_input(idx) is not None: + tensor = self.operator.get_input(idx) + nbDims = len(tensor.dims()) + axis = node.get_operator().attr.axis if node.get_operator().attr.axis >= 0 else node.get_operator().attr.axis + nbDims + + assert axis < nbDims, ( + f"export softmax: attribute axis == {axis} should be less than {nbDims}" + ) + + postAxisElems = 1 + for i in range(axis + 1, nbDims): + postAxisElems *= tensor.dims()[i] + + preAxisElems = 1 + for i in range(axis): + preAxisElems *= tensor.dims()[i] + + if (inputIndex == 0): + self.attributes["axis_size_post"] = postAxisElems + self.attributes["axis_size_pre"] = preAxisElems + + self.attributes["axis_size"] = [None] * self.attributes["nb_in"] + else: + assert self.attributes["axis_size_post"] == postAxisElems, ( + f"export concat: axis_size_post {self.attributes['axis_size_post']} != {postAxisElems}" + ) + assert self.attributes["axis_size_pre"] == preAxisElems, ( + f"export concat: axis_size_pre {self.attributes['axis_size_pre']} != {preAxisElems}" + ) + + self.attributes["axis_size"][idx] = tensor.dims()[axis] + else: + assert false, ( + f"export concat: input {idx} is None, not implemented") + + inputIndex += 1 + + self.config_template = str(ROOT / "templates" / "configuration" / "concat_config.jinja") + self.forward_template = str(ROOT / "templates" / "kernel_forward" / "concat_forward.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "concat.hpp"), ] \ No newline at end of file diff --git a/aidge_export_cpp/templates/configuration/_def_io.jinja b/aidge_export_cpp/templates/configuration/_def_io.jinja index 66756cf8f501035f7222272f9c410908f499f06f..f44454769bc66e5d15e93834b28e088525930271 100644 --- a/aidge_export_cpp/templates/configuration/_def_io.jinja +++ b/aidge_export_cpp/templates/configuration/_def_io.jinja @@ -4,6 +4,7 @@ #define {{ in_name[inidx]|upper }}_NB_CHANNELS {{ in_chan[inidx] }} #define {{ in_name[inidx]|upper }}_IN_HEIGHT {{ in_height[inidx] }} #define {{ in_name[inidx]|upper 
}}_IN_WIDTH {{ in_width[inidx] }} +#define {{ in_name[inidx]|upper }}_IN_BATCH {{ in_batch[inidx] }} {% endfor %} // OUTPUT CONF @@ -11,4 +12,5 @@ #define {{ out_name[outidx]|upper }}_NB_OUTPUTS {{ out_chan[outidx] }} #define {{ out_name[outidx]|upper }}_OUT_HEIGHT {{ out_height[outidx] }} #define {{ out_name[outidx]|upper }}_OUT_WIDTH {{ out_width[outidx] }} +#define {{ out_name[outidx]|upper }}_OUT_BATCH {{ out_batch[outidx] }} {% endfor %} diff --git a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja index 701ba7c46e4727eca86fcabf3ed997cab69f4e92..ae7ef5760a63689d11f6d7369e387b55b7cb3d15 100644 --- a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja +++ b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja @@ -1,11 +1,13 @@ {#- For name header -#} #ifndef {{ name|upper }}_LAYER_H #define {{ name|upper }}_LAYER_H +#include "kernels/rescaling.hpp" {# For layer configuration -#} {% include "./_def_io.jinja" %} {% include "./_meminfo.jinja" %} #define {{ name|upper }}_ACTIVATION {{ activation }} #define {{ name|upper }}_EPSILON {{ epsilon }} +static const {{ rescaling }} {{ name|upper }}_RESCALING = {}; #endif /* {{ name|upper }}_LAYER_H */ diff --git a/aidge_export_cpp/templates/configuration/concat_config.jinja b/aidge_export_cpp/templates/configuration/concat_config.jinja new file mode 100644 index 0000000000000000000000000000000000000000..ea8246db9a315a371e0cacea5d45d07fa2b8f7e8 --- /dev/null +++ b/aidge_export_cpp/templates/configuration/concat_config.jinja @@ -0,0 +1,18 @@ +{#- For name header -#} +#ifndef {{ name|upper }}_LAYER_H +#define {{ name|upper }}_LAYER_H + +{% include "./_def_io.jinja" %} +{% include "./_meminfo.jinja" %} + +// Attributes +#define {{ name|upper }}_NB_INPUTS {{ nb_in }} +#define {{ name|upper }}_AXIS {{ axis }} +{%- for i in range(nb_in) %} +#define {{ name|upper }}_INPUT_{{i}}_SIZE {{ axis_size[i] }} +{%- endfor %} + +#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }} +#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }} + +#endif /* {{ name|upper }}_LAYER_H */ diff --git a/aidge_export_cpp/templates/configuration/pad_config.jinja b/aidge_export_cpp/templates/configuration/pad_config.jinja new file mode 100644 index 0000000000000000000000000000000000000000..8b21577fe4d6f52ddb36ae796740f265db3d45cc --- /dev/null +++ b/aidge_export_cpp/templates/configuration/pad_config.jinja @@ -0,0 +1,13 @@ +{#- For name header -#} +#ifndef {{ name|upper }}_LAYER_H +#define {{ name|upper }}_LAYER_H +{# For layer configuration -#} +{% include "./_def_io.jinja" %} +{% include "./_meminfo.jinja" %} +#define {{ name|upper }}_PADDING_BOTTOM {{ padding[2] }} +#define {{ name|upper }}_PADDING_RIGHT {{ padding[3] }} +#define {{ name|upper }}_PADDING_TOP {{ padding[0] }} +#define {{ name|upper }}_PADDING_LEFT {{ padding[1] }} +#define {{ name|upper }}_BORDER_VALUE {{ border_value }} + +#endif /* {{ name|upper }}_LAYER_H */ diff --git a/aidge_export_cpp/templates/configuration/softmax_config.jinja b/aidge_export_cpp/templates/configuration/softmax_config.jinja new file mode 100644 index 0000000000000000000000000000000000000000..e9661bc553bfefb5a0fb12be5fe87106ac90e4a9 --- /dev/null +++ b/aidge_export_cpp/templates/configuration/softmax_config.jinja @@ -0,0 +1,14 @@ +{#- For name header -#} +#ifndef {{ name|upper }}_LAYER_H +#define {{ name|upper }}_LAYER_H +{# For layer configuration -#} +{% include "./_def_io.jinja" %} +{% include "./_meminfo.jinja" %} + +{#- 
Calculate sizes #} +{%- set weights_size = out_chan[0] * in_chan[0] * in_height[0] * in_width[0] %} +#define {{ name|upper }}_AXIS_SIZE {{ axis_size }} +#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }} +#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }} + +#endif /* {{ name|upper }}_LAYER_H */ diff --git a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja index 5a759b839cd0b04b3b82f8ca4cb8dd1b0201f4f7..03fd8e89921bfa27f4eeb33b05a47b40329fa5de 100644 --- a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja +++ b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja @@ -1,9 +1,10 @@ {% filter indent(width=4, first=False) %} {% include "./_mem_offset.jinja" %} -batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS, +batchnorm_forward<{{ out_name[0]|upper }}_OUT_BATCH, + {{ out_name[0]|upper }}_NB_OUTPUTS, {{ out_name[0]|upper }}_OUT_HEIGHT, {{ out_name[0]|upper }}_OUT_WIDTH, {{name|upper}}_ACTIVATION> - ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON); + ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON, {{name|upper}}_RESCALING); {% include "./_save_outputs.jinja" %} {% endfilter %} diff --git a/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja b/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja new file mode 100644 index 0000000000000000000000000000000000000000..7a77e904db6c18f338f93099f4f117c9285bf6fc --- /dev/null +++ b/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja @@ -0,0 +1,22 @@ +{% filter indent(width=4, first=False) %} +{% include "./_mem_offset.jinja" %} +const float* {{ name|upper }}_INPUTS[] = { + {%- for i in range(nb_in) -%} + {{ in_name[i] }}{{ ", " if not loop.last else "" }} + {%- endfor -%} +}; + +unsigned int {{ name|upper }}_SIZES[] = { + {%- for i in range(nb_in) -%} + {{ name|upper }}_INPUT_{{i}}_SIZE{{ ", " if not loop.last else "" }} + {%- endfor -%} +}; + +concat_forward<{{ name|upper }}_AXIS_SIZE_POST, + {{ name|upper }}_AXIS_SIZE_PRE, + {{ nb_in }}, + float> ( + {{ name|upper }}_INPUTS, + {{ name|upper }}_SIZES, + {{ out_name[0] }}); + {% endfilter %} diff --git a/aidge_export_cpp/templates/kernel_forward/pad_forward.jinja b/aidge_export_cpp/templates/kernel_forward/pad_forward.jinja new file mode 100644 index 0000000000000000000000000000000000000000..721418709f589d56723156797d7e45afe1259a7b --- /dev/null +++ b/aidge_export_cpp/templates/kernel_forward/pad_forward.jinja @@ -0,0 +1,16 @@ +{% filter indent(width=4, first=False) %} +{% include "./_mem_offset.jinja" %} +pad_forward<{{ in_name[0]|upper }}_IN_BATCH, + {{ in_name[0]|upper }}_NB_CHANNELS, + {{ in_name[0]|upper }}_IN_HEIGHT, + {{ in_name[0]|upper }}_IN_WIDTH, + {{ out_name[0]|upper }}_NB_OUTPUTS, + {{ out_name[0]|upper }}_OUT_HEIGHT, + {{ out_name[0]|upper }}_OUT_WIDTH, + {{name|upper}}_PADDING_TOP, + {{name|upper}}_PADDING_LEFT, + {{name|upper}}_PADDING_BOTTOM, + {{name|upper}}_PADDING_RIGHT> + ({{name|upper}}_BORDER_VALUE, {{in_name[0]}}, {{out_name[0]}}); +{% include "./_save_outputs.jinja" %} +{% endfilter %} diff --git a/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja new file mode 100644 index 0000000000000000000000000000000000000000..7c8e067f34bb2167544bab017e6b581345ba8bb2 --- /dev/null +++ 
b/aidge_export_cpp/templates/kernel_forward/softmax_forward.jinja @@ -0,0 +1,8 @@ +{% filter indent(width=4, first=False) %} +{% include "./_mem_offset.jinja" %} +softmax_forward<{{ name|upper }}_AXIS_SIZE, + {{ name|upper }}_AXIS_SIZE_POST, + {{ name|upper }}_AXIS_SIZE_PRE> + ({{in_name[0]}}, {{out_name[0]}}); +{% include "./_save_outputs.jinja" %} +{% endfilter %} diff --git a/aidge_export_cpp/unit_tests/test_export.py b/aidge_export_cpp/unit_tests/test_export.py index d900df83285f9b43a098b00d5b853391e7f97f92..607778d23deda862db73f5908fd1caa6ccc1d95b 100644 --- a/aidge_export_cpp/unit_tests/test_export.py +++ b/aidge_export_cpp/unit_tests/test_export.py @@ -3,9 +3,12 @@ import aidge_core import aidge_backend_cpu import aidge_export_cpp import numpy as np +import operator +from functools import reduce import subprocess import re +import shutil from aidge_core.utils import run_command def initFiller(model): @@ -32,6 +35,32 @@ def initFiller(model): else: pass +def _np_init(shape, dtype=np.float32): + """ + Generates a NumPy array with the given shape, filled with random values between -1 and 1 + with a step of 0.1. + + :param shape: Tuple of dimensions for the array + :param dtype: Data type of the output array (default: np.float32) + :return: A NumPy array with the given shape and dtype + """ + total_elements = reduce(operator.mul, shape, 1) + data = (np.random.randint(0, 21, size=total_elements) - 10) / 10.0 + return data.reshape(shape).astype(dtype) + +def _np_init_ones(shape, default_value=0.01, dtype=np.float32): + """ + Generates a NumPy array with the given shape, filled with random values between -1 and 1 + with a step of 0.1. + + :param shape: Tuple of dimensions for the array + :param dtype: Data type of the output array (default: np.float32) + :return: A NumPy array with the given shape and dtype + """ + total_elements = reduce(operator.mul, shape, 1) + data = np.ones(total_elements) * default_value + return data.reshape(shape).astype(dtype) + class test_operator_export(unittest.TestCase): @@ -43,7 +72,7 @@ class test_operator_export(unittest.TestCase): def tearDown(self): pass - def unit_test_export(self, graph_view, op_name, in_dims): + def unit_test_export(self, graph_view, op_name, in_dims, random_inputs=True, random_weights=True, default_value=0.01): """ TODO: * Handle multiple dataformat @@ -56,14 +85,34 @@ class test_operator_export(unittest.TestCase): 4- Retrieve standard output and using regex to now if the results are the same """ graph_view.compile("cpu", aidge_core.dtype.float32, dims=in_dims) + + for node in graph_view.get_nodes(): + if node.type() == "Producer": + prod_op = node.get_operator() + value = prod_op.get_output(0) + + if (random_weights): + tensor = aidge_core.Tensor(_np_init(value.dims())) + + node.get_operator().set_output(0, tensor) + else: + aidge_core.constant_filler(value, default_value) + + scheduler = aidge_core.SequentialScheduler(graph_view) - in_tensor = [aidge_core.Tensor(np.random.random(in_dim).astype(np.float32)) for in_dim in in_dims] + if (random_inputs): + in_tensor = [aidge_core.Tensor(_np_init(in_dim)) for in_dim in in_dims] + else: + in_tensor = [aidge_core.Tensor(_np_init_ones(in_dim, default_value)) for in_dim in in_dims] + scheduler.forward(data=in_tensor) # Note the convention ``<op_name>_test`` is useful for gitignore to avoid pushing generated export by accident. 
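+        # Clear any leftover export from a previous run; ignore_errors=True keeps the first run (when the folder does not exist yet) from failing.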
export_folder = op_name + "_test" + shutil.rmtree(export_folder, ignore_errors=True) + # Export the model in C++ standalone aidge_core.export_utils.scheduler_export( scheduler, @@ -112,6 +161,46 @@ class test_operator_export(unittest.TestCase): self.unit_test_export(model, "FC_flat", [[1, 6, 1, 1]]) + def test_export_softmax(self): + print("Softmax") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=1, name="sf0") + ]) + + self.unit_test_export(model, "Softmax", [[1, 10]]) + + def test_export_softmax_batch(self): + print("SoftmaxBatch") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=1, name="sf0") + ]) + + self.unit_test_export(model, "SoftmaxBatch", [[3, 10]]) + + def test_export_softmax_axis_2(self): + print("SoftmaxAxis2") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=2, name="sf0") + ]) + + self.unit_test_export(model, "SoftmaxAxis2", [[1, 10, 3, 7]]) + + def test_export_softmax_axis_negative(self): + print("SoftmaxAxisNegative") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=-3, name="sf0") + ]) + + self.unit_test_export(model, "SoftmaxAxisNegative", [[1, 10, 3, 7]]) + + def test_export_softmax_axis_0(self): + print("SoftmaxAxis0") + model = aidge_core.sequential([ + aidge_core.Softmax(axis=0, name="sf0") + ]) + + self.unit_test_export(model, "SoftmaxAxis0", [[10]]) + @unittest.skip("Currently this test is failing") def test_export_FC_image_in(self): """Test exporting a FC operator with a HWC input. @@ -122,6 +211,347 @@ class test_operator_export(unittest.TestCase): initFiller(model) self.unit_test_export(model, "FC_img", [[1, 3, 2, 2]]) + def test_export_relu(self): + print("ReLU") + model = aidge_core.sequential([ + aidge_core.ReLU(name="relu0") + ]) + + self.unit_test_export(model, "ReLU", [[1, 10]]) + + def test_export_add(self): + print("Add") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 5], name="producer"), + aidge_core.Add(name="add") + ]) + + self.unit_test_export(model, "Add", [[1, 5, 5]]) + + def test_export_add_larger(self): + print("AddLarger") + model = aidge_core.sequential([ + aidge_core.Producer([1, 7, 5], name="producer"), + aidge_core.Add(name="add") + ]) + + self.unit_test_export(model, "Add", [[1, 7, 5]]) + + def test_export_add_higher(self): + print("AddHigher") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Add(name="add") + ]) + + self.unit_test_export(model, "Add", [[1, 5, 7]]) + + # "Broadcast not supported yet in export operator" + @unittest.expectedFailure + def test_export_add_simple_broadcast(self): + print("AddSimpleBroadcast") + model = aidge_core.sequential([ + aidge_core.Producer([1, 1, 5], name="producer"), + aidge_core.Add(name="add") + ]) + + self.unit_test_export(model, "AddSimpleBroadcast", [[1, 7, 5]]) + + # "Broadcast not supported yet in export operator" + @unittest.expectedFailure + def test_export_add_double_broadcast(self): + print("AddDoubleBroadcast") + model = aidge_core.sequential([ + aidge_core.Producer([1, 1, 7], name="producer"), + aidge_core.Add(name="add") + ]) + + self.unit_test_export(model, "AddDoubleBroadcast", [[1, 5, 1]]) + + def test_export_sub(self): + print("Sub") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 5], name="producer"), + aidge_core.Sub(name="sub") + ]) + + self.unit_test_export(model, "Sub", [[1, 5, 5]]) + + def test_export_sub_larger(self): + print("SubLarger") + model = aidge_core.sequential([ + aidge_core.Producer([1, 7, 5], name="producer"), + aidge_core.Sub(name="sub") + 
]) + + self.unit_test_export(model, "Sub", [[1, 7, 5]]) + + def test_export_sub_higher(self): + print("SubHigher") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Sub(name="sub") + ]) + + self.unit_test_export(model, "Sub", [[1, 5, 7]]) + + # "Broadcast not supported yet in export operator" + @unittest.expectedFailure + def test_export_sub_simple_broadcast(self): + print("SubSimpleBroadcast") + model = aidge_core.sequential([ + aidge_core.Producer([1, 1, 5], name="producer"), + aidge_core.Sub(name="sub") + ]) + + self.unit_test_export(model, "SubSimpleBroadcast", [[1, 7, 5]]) + + # "Broadcast not supported yet in export operator" + @unittest.expectedFailure + def test_export_sub_double_broadcast(self): + print("SubDoubleBroadcast") + model = aidge_core.sequential([ + aidge_core.Producer([1, 1, 7], name="producer"), + aidge_core.Sub(name="sub") + ]) + + self.unit_test_export(model, "SubDoubleBroadcast", [[1, 5, 1]]) + + def test_export_mul(self): + print("Mul") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 5], name="producer"), + aidge_core.Mul(name="mul") + ]) + + self.unit_test_export(model, "Mul", [[1, 5, 5]]) + + def test_export_mul_larger(self): + print("MulLarger") + model = aidge_core.sequential([ + aidge_core.Producer([1, 7, 5], name="producer"), + aidge_core.Mul(name="mul") + ]) + + self.unit_test_export(model, "Mul", [[1, 7, 5]]) + + def test_export_mul_higher(self): + print("MulHigher") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Mul(name="mul") + ]) + + self.unit_test_export(model, "Mul", [[1, 5, 7]]) + + # "Broadcast not supported yet in export operator" + @unittest.expectedFailure + def test_export_mul_simple_broadcast(self): + print("MulSimpleBroadcast") + model = aidge_core.sequential([ + aidge_core.Producer([1, 1, 5], name="producer"), + aidge_core.Mul(name="mul") + ]) + + self.unit_test_export(model, "MulSimpleBroadcast", [[1, 7, 5]]) + + # "Broadcast not supported yet in export operator" + @unittest.expectedFailure + def test_export_mul_double_broadcast(self): + print("MulDoubleBroadcast") + model = aidge_core.sequential([ + aidge_core.Producer([1, 1, 7], name="producer"), + aidge_core.Mul(name="mul") + ]) + + self.unit_test_export(model, "MulDoubleBroadcast", [[1, 5, 1]]) + + def test_export_mul_batch(self): + print("MulBatch") + model = aidge_core.sequential([ + aidge_core.Producer([3, 5, 7], name="producer"), + aidge_core.Mul(name="mul") + ]) + + self.unit_test_export(model, "MulBatch", [[3, 5, 7]]) + + def test_export_concat(self): + print("Concat") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Concat(nb_inputs=2, axis=1, name="concat") + ]) + + self.unit_test_export(model, "Concat", [[1, 5, 7]]) + + def test_export_concat_axis_2(self): + print("ConcatAxis2") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Concat(nb_inputs=2, axis=2, name="concat") + ]) + + self.unit_test_export(model, "ConcatAxis2", [[1, 5, 7]]) + + def test_export_concat_axis_negative(self): + print("ConcatAxisNegative") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Concat(nb_inputs=2, axis=-2, name="concat") + ]) + + self.unit_test_export(model, "ConcatAxisNegative", [[1, 5, 7]]) + + def test_export_conv2D(self): + print("Conv2D") + model = aidge_core.sequential([ + aidge_core.Conv2D(in_channels=3, out_channels=3, 
kernel_dims=(3, 3), name="conv") + ]) + + self.unit_test_export(model, "Conv2D", [[1, 3, 12, 12]], False, False) + + def test_export_max_pooling(self): + print("MaxPooling2D") + model = aidge_core.sequential([ + aidge_core.MaxPooling2D(kernel_dims=(3, 3), name="max_pool") + ]) + + self.unit_test_export(model, "MaxPooling2D", [[1, 2, 12, 12]], False, False) + + def test_export_avg_pooling(self): + print("AvgPooling2D") + model = aidge_core.sequential([ + aidge_core.AvgPooling2D(kernel_dims=(3, 3), name="avg_pool") + ]) + + self.unit_test_export(model, "AvgPooling2D", [[1, 2, 12, 12]], False, False) + + def test_export_pad2D(self): + print("Pad2D") + model = aidge_core.sequential([ + aidge_core.Pad2D((1, 1, 1, 1), name="pad2d") + ]) + + self.unit_test_export(model, "Pad2D", [[1, 1, 11, 11]]) + + def test_export_pad2D_larger(self): + print("Pad2DLarger") + model = aidge_core.sequential([ + aidge_core.Pad2D((1, 3, 1, 3), name="pad2d") + ]) + + self.unit_test_export(model, "Pad2DLarger", [[1, 1, 7, 11]]) + + def test_export_pad2D_higher(self): + print("Pad2DHigher") + model = aidge_core.sequential([ + aidge_core.Pad2D((3, 1, 3, 1), name="pad2d") + ]) + + self.unit_test_export(model, "Pad2DHigher", [[1, 1, 11, 7]]) + + def test_export_pad2D_mismatch(self): + print("Pad2DMismatch") + model = aidge_core.sequential([ + aidge_core.Pad2D((1, 3, 5, 7), name="pad2d") + ]) + + self.unit_test_export(model, "Pad2DMismatch", [[3, 5, 11, 7]]) + + def test_export_pad2D_denser(self): + print("Pad2DDenser") + model = aidge_core.sequential([ + aidge_core.Pad2D((3, 3, 3, 3), name="pad2d") + ]) + + self.unit_test_export(model, "Pad2DDenser", [[1, 5, 7, 11]]) + + def test_export_pad2D_with_bigger_batch_size(self): + print("Pad2DBiggerBatchSize") + model = aidge_core.sequential([ + aidge_core.Pad2D((1, 1, 1, 1), name="pad2d") + ]) + + self.unit_test_export(model, "Pad2DBiggerBatchSize", [[3, 5, 7, 11]]) + + @unittest.expectedFailure + def test_export_pad2D_not_constant(self): + print("Pad2DNotConstant") + model = aidge_core.sequential([ + aidge_core.Pad2D((3, 3, 3, 3), border_type=aidge_core.pad_border_type.Wrap, name="pad2d") + ]) + + self.unit_test_export(model, "Pad2DNotConstant", [[1, 5, 7, 11]]) + + def test_export_batchnorm2D(self): + print("BatchNormalization2D") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, "BatchNorm2D", [[1, 1, 5, 5]], False, False) + + def test_export_batchnorm2D_Larger(self): + print("BatchNormalization2DLarger") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, "BatchNorm2DLarger", [[1, 1, 5, 7]], False, False) + + def test_export_batchnorm2D_Higher(self): + print("BatchNormalization2DHigher") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, "BatchNorm2DHigher", [[1, 1, 7, 5]], False, False) + + def test_export_batchnorm2D_Denser(self): + print("BatchNormalization2DDenser") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, "BatchNorm2DDenser", [[1, 3, 5, 7]], False, False) + + def test_export_batchnorm2D_with_bigger_batch_size(self): + print("BatchNormalization2DBiggerBatchSize") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, 
"BatchNormalization2DBiggerBatchSize", [[4, 3, 5, 7]], False, False) + + + def test_export_batchnorm2D_Larger(self): + print("BatchNormalization2DLarger") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, "BatchNorm2DLarger", [[1, 1, 5, 7]], False, False) + + def test_export_batchnorm2D_Higher(self): + print("BatchNormalization2DHigher") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, "BatchNorm2DHigher", [[1, 1, 7, 5]], False, False) + + def test_export_batchnorm2D_Denser(self): + print("BatchNormalization2DDenser") + model = aidge_core.sequential([ + aidge_core.BatchNorm2D(nb_features=10, epsilon=2e-5, name="bn") + ]) + + self.unit_test_export(model, "BatchNorm2DDenser", [[1, 3, 5, 7]], False, False) + + + def test_export_cpp(self): + print("Export test to do") + def test_export_Conv(self): model = aidge_core.sequential([ aidge_core.Conv2D(1, 1, [3, 3], name="InputNode")