Commit 203fee0d authored by Gallas Gaye

fix: Batch support for batchnorm

parent d5659f7d
2 merge requests: !39 Update 0.2.1 -> 0.3.0, !36 feat: Add missing operators for AIDGE model benchmarking
@@ -8,7 +8,7 @@
 // WARNING: this kernel only works for 32-bits floating point values
-template<int NB_OUTPUTS,
+template<int NB_BATCHES, int NB_OUTPUTS,
          int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH,
          ActivationFunction_T ACTIVATION,
          typename Input_T, typename Output_T,
@@ -25,23 +25,19 @@ void batchnorm_forward (
     const double epsilon,
     const Rescaling_T& __restrict rescaling)
 {
-    for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
-        // If the variance is 0, we need to avoid division by 0
-        Output_T var = epsilon;
-
-        // If the variance is negative, we need to set it to 0 to avoid a sqrt of a negative number
-        if (variances[output] > 0.0)
-        {
-            var = sqrt(variances[output] + epsilon);
-        }
-
-        for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
-            for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
-                const int outputOffset = OUTPUTS_WIDTH * oy + ox;
-                const Output_T normalized = (inputs[outputOffset + output] - means[output]) / var;
-                const Output_T sAs = scales[output] * normalized + biases[output];
-                outputs[outputOffset + output] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
-            }
-        }
-    }
+    for (unsigned int batch = 0; batch < NB_BATCHES; ++batch) {
+        for (unsigned int output = 0; output < NB_OUTPUTS; ++output) {
+            // If the variance is 0, we need to avoid division by 0
+            Output_T var = sqrt(variances[output] > 0.0 ? variances[output] + epsilon : epsilon);
+
+            for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
+                for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
+                    const int outputOffset = batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT + OUTPUTS_WIDTH * oy + ox;
+
+                    const Output_T normalized = (inputs[outputOffset] - means[output]) / var;
+                    const Output_T sAs = scales[output] * normalized + biases[output];
+                    outputs[outputOffset] = activation_forward_value<Output_T>(sAs, output, ACTIVATION, rescaling);
+                }
+            }
+        }
+    }
 }
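Note on the new indexing: the batched offset batch * OUTPUTS_WIDTH * OUTPUTS_HEIGHT * NB_OUTPUTS + output * OUTPUTS_WIDTH * OUTPUTS_HEIGHT + OUTPUTS_WIDTH * oy + ox factors as ((batch * NB_OUTPUTS + output) * OUTPUTS_HEIGHT + oy) * OUTPUTS_WIDTH + ox, i.e. a contiguous NCHW layout. The following minimal standalone sketch of the batched loop nest is illustrative only (hypothetical name batchnorm_forward_sketch; identity activation and no rescaling, unlike the real kernel):

    // Sketch of the batched NCHW batchnorm loop nest; not the export's
    // actual kernel (activation_forward_value and Rescaling_T are omitted).
    #include <cmath>
    #include <cstdio>

    template<int NB_BATCHES, int NB_OUTPUTS, int OUTPUTS_HEIGHT, int OUTPUTS_WIDTH>
    void batchnorm_forward_sketch(const float* inputs, float* outputs,
                                  const float* biases, const float* variances,
                                  const float* means, const float* scales,
                                  double epsilon)
    {
        for (int batch = 0; batch < NB_BATCHES; ++batch) {
            for (int output = 0; output < NB_OUTPUTS; ++output) {
                // Clamp non-positive variances so sqrt never sees a negative
                // value and the divisor is never zero.
                const float var = std::sqrt(variances[output] > 0.0f
                                            ? variances[output] + (float)epsilon
                                            : (float)epsilon);
                for (int oy = 0; oy < OUTPUTS_HEIGHT; ++oy) {
                    for (int ox = 0; ox < OUTPUTS_WIDTH; ++ox) {
                        // NCHW offset: ((batch * C + channel) * H + oy) * W + ox
                        const int offset = ((batch * NB_OUTPUTS + output) * OUTPUTS_HEIGHT + oy)
                                           * OUTPUTS_WIDTH + ox;
                        const float normalized = (inputs[offset] - means[output]) / var;
                        outputs[offset] = scales[output] * normalized + biases[output];
                    }
                }
            }
        }
    }

    int main() {
        // 2 batches, 1 channel, 2x2 spatial: 8 values total.
        const float in[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        float out[8];
        const float bias[1] = {0.0f}, var[1] = {1.0f}, mean[1] = {4.5f}, scale[1] = {1.0f};
        batchnorm_forward_sketch<2, 1, 2, 2>(in, out, bias, var, mean, scale, 1e-5);
        for (float v : out) std::printf("%g ", v);  // prints in - mean, roughly
        std::printf("\n");
        return 0;
    }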
@@ -4,6 +4,7 @@
 #define {{ in_name[inidx]|upper }}_NB_CHANNELS {{ in_chan[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_HEIGHT {{ in_height[inidx] }}
 #define {{ in_name[inidx]|upper }}_IN_WIDTH {{ in_width[inidx] }}
+#define {{ in_name[inidx]|upper }}_IN_BATCH {{ in_batch[inidx] }}
 {% endfor %}
 // OUTPUT CONF
@@ -11,4 +12,5 @@
 #define {{ out_name[outidx]|upper }}_NB_OUTPUTS {{ out_chan[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_HEIGHT {{ out_height[outidx] }}
 #define {{ out_name[outidx]|upper }}_OUT_WIDTH {{ out_width[outidx] }}
+#define {{ out_name[outidx]|upper }}_OUT_BATCH {{ out_batch[outidx] }}
 {% endfor %}

 {% filter indent(width=4, first=False) %}
 {% include "./_mem_offset.jinja" %}
-batchnorm_forward<{{ out_name[0]|upper }}_NB_OUTPUTS,
+batchnorm_forward<{{ out_name[0]|upper }}_OUT_BATCH,
+                  {{ out_name[0]|upper }}_NB_OUTPUTS,
                   {{ out_name[0]|upper }}_OUT_HEIGHT,
                   {{ out_name[0]|upper }}_OUT_WIDTH,
                   {{name|upper}}_ACTIVATION>
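For illustration, here is roughly what the updated templates could expand to for a node and output both named bn1 with batch 1, 8 channels and a 16x16 feature map (all names and values are invented for the example; the argument list is elided):

    // Hypothetical generated configuration
    #define BN1_NB_OUTPUTS 8
    #define BN1_OUT_HEIGHT 16
    #define BN1_OUT_WIDTH 16
    #define BN1_OUT_BATCH 1

    // Hypothetical generated forward call: the batch size is now the
    // first template argument of batchnorm_forward
    batchnorm_forward<BN1_OUT_BATCH,
                      BN1_NB_OUTPUTS,
                      BN1_OUT_HEIGHT,
                      BN1_OUT_WIDTH,
                      BN1_ACTIVATION>
        (/* inputs, outputs, biases, variances, means, scales, epsilon, rescaling */);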