diff --git a/aidge_export_cpp/kernels/batchnorm.hpp b/aidge_export_cpp/kernels/batchnorm.hpp index 740ea21e6f66ba338985db4f724a5d57377e1f81..0260d42539a892150f1c13cd2a275bb579cd194a 100644 --- a/aidge_export_cpp/kernels/batchnorm.hpp +++ b/aidge_export_cpp/kernels/batchnorm.hpp @@ -2,7 +2,8 @@ #define __AIDGE_EXPORT_CPP_KERNELS_BATCHNORM__ #include "network/typedefs.hpp" -#include "kernels/rescaling.hpp" +#include "kernels/activation.hpp" + #include <math.h> // WARNING: this kernel only works for 32-bits floating point values @@ -11,7 +12,8 @@ template<int NB_OUTPUTS, int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH, ActivationFunction_T ACTIVATION, typename Input_T, typename Output_T, - typename Param_T> + typename Param_T, + typename Rescaling_T> __attribute__((always_inline)) inline void batchnorm_forward ( const Input_T* __restrict inputs, @@ -20,7 +22,8 @@ void batchnorm_forward ( const Param_T* __restrict variances, const Param_T* __restrict means, const Param_T* __restrict scales, - const double epsilon) + const double epsilon, + const Rescaling_T& __restrict rescaling) { for (unsigned int output = 0; output < NB_OUTPUTS; ++output) { const Output_T var = sqrt(variances[output] + epsilon); @@ -31,7 +34,7 @@ void batchnorm_forward ( const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var; const Output_T sAs = scales[output] * normalized + biases[output]; - outputs[outputOffset + output] = sat<Output_T>(sAs, output, ACTIVATION, NoScaling); + outputs[outputOffset + output] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling); } } } diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index 0f6d3c83f19cf0368e15005fe971fea031215070..b7e54727459075af5494769e3047084e0f662c63 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -317,4 +317,24 @@ class SoftmaxCPP(ExportNodeCpp): self.kernels_to_copy = [ str(ROOT / "kernels" / "softmax.hpp"), str(ROOT / "kernels" / "macs.hpp"), - ] \ No newline at end of file + ] + +@ExportLibCpp.register("BatchNorm2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) +class BatchNorm2DCPP(ExportNodeCpp): + def __init__(self, node, mem_info): + super().__init__(node, mem_info) + self.attributes["activation"] = "Linear" + self.attributes["rescaling"] = "NoScaling" + self.attributes["epsilon"] = node.get_operator().attr.epsilon + self.config_template = str( + ROOT / "templates" / "configuration" / "batchnorm_config.jinja") + self.forward_template = str( + ROOT / "templates" / "kernel_forward" / "batchnorm_forward.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "batchnorm.hpp"), + str(ROOT / "kernels" / "macs.hpp"), + str(ROOT / "kernels" / "activation.hpp"), + str(ROOT / "kernels" / "rescaling.hpp") + ] + diff --git a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja index 701ba7c46e4727eca86fcabf3ed997cab69f4e92..bc01e3b964d3741548bb3249d88aa9034ecaa343 100644 --- a/aidge_export_cpp/templates/configuration/batchnorm_config.jinja +++ b/aidge_export_cpp/templates/configuration/batchnorm_config.jinja @@ -7,5 +7,6 @@ {% include "./_meminfo.jinja" %} #define {{ name|upper }}_ACTIVATION {{ activation }} #define {{ name|upper }}_EPSILON {{ epsilon }} +static const {{ rescaling }} {{ name|upper }}_RESCALING = {}; #endif /* {{ name|upper }}_LAYER_H */ diff --git a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja index 5a759b839cd0b04b3b82f8ca4cb8dd1b0201f4f7..05e5154c02986817f1af577fbd4124792a987b85 100644 --- a/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja +++ b/aidge_export_cpp/templates/kernel_forward/batchnorm_forward.jinja @@ -4,6 +4,6 @@ batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS, {{ out_name[0]|upper }}_OUT_HEIGHT, {{ out_name[0]|upper }}_OUT_WIDTH, {{name|upper}}_ACTIVATION> - ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON); + ({{in_name[0]}}, {{out_name[0]}}, {{in_name[1]}}, {{in_name[2]}}, {{in_name[3]}}, {{in_name[4]}}, {{name|upper}}_EPSILON, {{name|upper}}_RESCALING); {% include "./_save_outputs.jinja" %} {% endfilter %}