Commit 7230dae2 authored by Olivier BICHLER

Improved elemwise to be fully templated

parent 03009d35
2 merge requests: !71 0.4.0, !55 Improved elemwise and matmul to be fully templated
Pipeline #76304 passed with warnings
@@ -4,9 +4,9 @@
 #include "network/typedefs.hpp"
 #include "network/activation_utils.hpp"
 
-template<int NB_ELTS, ElemWise_T ELEM_OP,
-         int INPUT_A_DIMS[], int INPUT_B_DIMS[], int OUTPUT_DIMS[],
-         int SIZE_DIM_IN_A, int SIZE_DIM_IN_B, int SIZE_DIM_OUT, int OUT_SIZE,
+template<int NB_MAT, ElemWise_T ELEM_OP,
+         int INPUT1_CONT_SIZE, int INPUT2_CONT_SIZE, int OUTPUT_CONT_SIZE,
+         const int OFFSET_IN1[], const int OFFSET_IN2[],
          ActivationFunction_T ACTIVATION,
          typename Input_T, typename Output_T, typename Rescaling_T>
 __attribute__((always_inline)) inline
@@ -16,139 +16,35 @@ void elemwise_forward(
     const Input_T* __restrict inputs1,
     const Input_T* __restrict inputs2)
 {
-    if (std::is_floating_point<Input_T>::value)
-    {
-        Input_T val = 0;
-        int ndim_a[SIZE_DIM_OUT];
-        int ndim_b[SIZE_DIM_OUT];
-        for (int i = 0; i < SIZE_DIM_OUT; i++) {
-            int idx = SIZE_DIM_OUT - SIZE_DIM_IN_A;
-            ndim_a[i] = (i < idx) ? 1 : INPUT_A_DIMS[i - idx];
-        }
-        for (int i = 0; i < SIZE_DIM_OUT; i++) {
-            int idx = SIZE_DIM_OUT - SIZE_DIM_IN_B;
-            ndim_b[i] = (i < idx) ? 1 : INPUT_B_DIMS[i - idx];
-        }
-        // Find the highest equal dimension
-        int contiguousIdx = SIZE_DIM_OUT - 1;
-        for (int i = contiguousIdx; ndim_a[i] == ndim_b[i]; i--) {
-            contiguousIdx = i;
-        }
-        // Compute the highest number of contiguous data for each Tensor
-        int input0_contiguous_size = 1;
-        for (int i = contiguousIdx; i < SIZE_DIM_OUT; ++i) {
-            input0_contiguous_size *= ndim_a[i];
-        }
-        int input1_contiguous_size = 1;
-        for (int i = contiguousIdx; i < SIZE_DIM_OUT; ++i) {
-            input1_contiguous_size *= ndim_b[i];
-        }
-        int output_contiguous_size = 1;
-        for (int i = contiguousIdx; i < SIZE_DIM_OUT; ++i) {
-            output_contiguous_size *= OUTPUT_DIMS[i];
-        }
-        // Initialize strides to iterate through data because of broadcasting
-        int stride_post0[contiguousIdx];
-        int stride_post1[contiguousIdx];
-        int stride_step0[contiguousIdx];
-        int stride_step1[contiguousIdx];
-        if (contiguousIdx > 0) {
-            stride_post0[contiguousIdx - 1] = 1;
-            stride_post1[contiguousIdx - 1] = 1;
-            for (int i = contiguousIdx - 2; i != -1; --i) {
-                stride_post0[i] = stride_post0[i + 1] * ndim_a[i + 1];
-                stride_post1[i] = stride_post1[i + 1] * ndim_b[i + 1];
-            }
-            for (int i = 0; i < contiguousIdx; ++i) {
-                stride_step0[i] = (ndim_a[i] == 1) ? 1 - stride_post0[i] : 1;
-                stride_step1[i] = (ndim_b[i] == 1) ? 1 - stride_post1[i] : 1;
-            }
-        }
-        int offsetIn0 = 0;
-        int offsetIn1 = 0;
-        int offsetOut = 0;
-        int nbMatrices = 1;
-        for (int i = 0; i < contiguousIdx; ++i) {
-            nbMatrices *= OUTPUT_DIMS[i];
-        }
-        int dim = contiguousIdx - 1;
-        for (int stack = 0; stack < nbMatrices;) {
-            for (int i = 0; i < output_contiguous_size; ++i) {
-                int in0_id = (input0_contiguous_size != 1) ? i : 0;
-                int in1_id = (input1_contiguous_size != 1) ? i : 0;
-                switch (ELEM_OP) {
-                case Add:
-                    outputs[i + offsetOut * output_contiguous_size] = inputs1[in0_id + offsetIn0 * input0_contiguous_size] + inputs2[in1_id + offsetIn1 * input1_contiguous_size];
-                    break;
-                case Sub:
-                    outputs[i + offsetOut * output_contiguous_size] = inputs1[in0_id + offsetIn0 * input0_contiguous_size] - inputs2[in1_id + offsetIn1 * input1_contiguous_size];
-                    break;
-                case Mul:
-                    outputs[i + offsetOut * output_contiguous_size] = inputs1[in0_id + offsetIn0 * input0_contiguous_size] * inputs2[in1_id + offsetIn1 * input1_contiguous_size];
-                    break;
-                case Div:
-                    outputs[i + offsetOut * output_contiguous_size] = inputs1[in0_id + offsetIn0 * input0_contiguous_size] / inputs2[in1_id + offsetIn1 * input1_contiguous_size];
-                    break;
-                default:
-                    val = inputs1[in0_id + offsetIn0 * input0_contiguous_size];
-                    outputs[i + offsetOut * output_contiguous_size] = activation_forward_value<Output_T>(val, i, ACTIVATION, rescaling);
-                    break;
-                }
-            }
-            if (++stack < nbMatrices) {
-                int tmp_stack = stack;
-                while (tmp_stack % OUTPUT_DIMS[dim] == 0) {
-                    tmp_stack /= OUTPUT_DIMS[dim];
-                    dim--;
-                }
-                offsetIn0 += stride_step0[dim];
-                offsetIn1 += stride_step1[dim];
-                ++offsetOut;
-                dim = contiguousIdx - 1;
-            }
-        }
-    }
-    else
-    {
-        int32_t val = 0;
-        switch (ELEM_OP) {
-        case Add:
-            for (int i = 0; i < NB_ELTS; ++i) {
-                val = inputs1[i] + inputs2[i];
-                outputs[i] = activation_forward_value<Output_T>(val, i, ACTIVATION, rescaling);
-            }
-            break;
-        case Sub:
-            for (int i = 0; i < NB_ELTS; ++i) {
-                val = inputs1[i] - inputs2[i];
-                outputs[i] = activation_forward_value<Output_T>(val, i, ACTIVATION, rescaling);
-            }
-            break;
-        case Mul:
-            for (int i = 0; i < NB_ELTS; ++i) {
-                val = inputs1[i] * inputs2[i];
-                outputs[i] = activation_forward_value<Output_T>(val, i, ACTIVATION, rescaling);
-            }
-            break;
-        case Div:
-            for (int i = 0; i < NB_ELTS; ++i) {
-                val = inputs1[i] / inputs2[i];
-                outputs[i] = activation_forward_value<Output_T>(val, i, ACTIVATION, rescaling);
-            }
-            break;
-        default:
-            for (int i = 0; i < NB_ELTS; ++i) {
-                val = inputs1[i];
-                outputs[i] = activation_forward_value<Output_T>(val, i, ACTIVATION, rescaling);
-            }
-            break;
-        }
-    }
+    for (int stack = 0; stack < NB_MAT; ++stack) {
+        for (int i = 0; i < OUTPUT_CONT_SIZE; ++i) {
+            const int in0_id = (INPUT1_CONT_SIZE != 1) ? i : 0;
+            const int in1_id = (INPUT2_CONT_SIZE != 1) ? i : 0;
+            const int out_id = i + stack * OUTPUT_CONT_SIZE;
+            const auto val1 = inputs1[in0_id + OFFSET_IN1[stack] * INPUT1_CONT_SIZE];
+            const auto val2 = inputs2[in1_id + OFFSET_IN2[stack] * INPUT2_CONT_SIZE];
+            Input_T val = 0;
+            switch (ELEM_OP) {
+            case Add:
+                val = val1 + val2;
+                break;
+            case Sub:
+                val = val1 - val2;
+                break;
+            case Mul:
+                val = val1 * val2;
+                break;
+            case Div:
+                val = val1 / val2;
+                break;
+            default:
+                val = val1;
+                break;
+            }
+            outputs[out_id] = activation_forward_value<Output_T>(val, out_id, ACTIVATION, rescaling);
+        }
+    }
 }
......
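The point of the rewrite is visible above: all shape analysis is gone from the kernel, which now simply walks NB_MAT contiguous blocks whose per-block input offsets arrive as template arguments. Below is a minimal, self-contained sketch of that indexing scheme, not the repository's code: Add only, no activation or rescaling, and hypothetical shapes A = (2, 1, 4) and B = (3, 4) broadcast to an output of (2, 3, 4).

#include <cstdio>

// Simplified stand-in for the new kernel (Add only, no activation/rescaling):
// NB_MAT contiguous blocks, per-block input offsets baked in at compile time.
template<int NB_MAT, int IN1_CONT, int IN2_CONT, int OUT_CONT,
         const int OFFSET_IN1[], const int OFFSET_IN2[]>
void add_forward(float* out, const float* in1, const float* in2)
{
    for (int stack = 0; stack < NB_MAT; ++stack) {
        for (int i = 0; i < OUT_CONT; ++i) {
            const int in0_id = (IN1_CONT != 1) ? i : 0;  // size-1 blocks broadcast
            const int in1_id = (IN2_CONT != 1) ? i : 0;
            out[i + stack * OUT_CONT] =
                in1[in0_id + OFFSET_IN1[stack] * IN1_CONT]
                + in2[in1_id + OFFSET_IN2[stack] * IN2_CONT];
        }
    }
}

// Hypothetical offsets an exporter would emit for A = (2, 1, 4), B = (3, 4),
// out = (2, 3, 4): six contiguous blocks of four elements each.
constexpr int OFFSETS_A[] = {0, 0, 0, 1, 1, 1};  // A's middle dim is broadcast
constexpr int OFFSETS_B[] = {0, 1, 2, 0, 1, 2};  // B repeats over out's first dim

int main()
{
    float a[8], b[12], out[24];
    for (int i = 0; i < 8; ++i)  a[i] = float(i);
    for (int i = 0; i < 12; ++i) b[i] = float(100 * i);
    add_forward<6, 4, 4, 4, OFFSETS_A, OFFSETS_B>(out, a, b);
    printf("%g %g\n", out[0], out[23]);  // 0 (= 0 + 0) and 1107 (= 7 + 1100)
    return 0;
}

Because every quantity is a template argument, the compiler can constant-fold the branches on the *_CONT_SIZE parameters and unroll the block loop, which is what makes the kernel "fully templated".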
@@ -15,6 +15,80 @@ class ElemWise(ExportNodeCpp):
         self.attributes["shift_value"] = 0
         self.attributes["coef_value"] = 1
+        nbdims_out = len(self.attributes["out_dims"][0])
+        dims_a = self.attributes["in_dims"][0]
+        dims_b = self.attributes["in_dims"][1]
+        ndim_a = [0] * nbdims_out
+        ndim_b = [0] * nbdims_out
+        idx_a = nbdims_out - len(dims_a)
+        for i in range(nbdims_out):
+            ndim_a[i] = 1 if i < idx_a else dims_a[i - idx_a]
+        idx_b = nbdims_out - len(dims_b)
+        for i in range(nbdims_out):
+            ndim_b[i] = 1 if i < idx_b else dims_b[i - idx_b]
+        # Find highest equal dimension
+        contiguousIdx = nbdims_out - 1
+        for i in range(nbdims_out - 1, -1, -1):
+            if ndim_a[i] != ndim_b[i]:
+                break
+            contiguousIdx = i
+        # Compute the highest number of contiguous data
+        input0_contiguous_size = 1
+        input1_contiguous_size = 1
+        output_contiguous_size = 1
+        for i in range(contiguousIdx, nbdims_out):
+            input0_contiguous_size *= ndim_a[i]
+            input1_contiguous_size *= ndim_b[i]
+            output_contiguous_size *= self.attributes["out_dims"][0][i]
+        self.attributes["input1_cont_size"] = input0_contiguous_size
+        self.attributes["input2_cont_size"] = input1_contiguous_size
+        self.attributes["output_cont_size"] = output_contiguous_size
+        # Initialize strides for broadcasting
+        stride_post0 = [0] * contiguousIdx
+        stride_post1 = [0] * contiguousIdx
+        stride_step0 = [0] * contiguousIdx
+        stride_step1 = [0] * contiguousIdx
+        if contiguousIdx > 0:
+            stride_post0[contiguousIdx - 1] = 1
+            stride_post1[contiguousIdx - 1] = 1
+            for i in range(contiguousIdx - 2, -1, -1):
+                stride_post0[i] = stride_post0[i + 1] * ndim_a[i + 1]
+                stride_post1[i] = stride_post1[i + 1] * ndim_b[i + 1]
+            for i in range(contiguousIdx):
+                stride_step0[i] = 1 - stride_post0[i] if ndim_a[i] == 1 else 1
+                stride_step1[i] = 1 - stride_post1[i] if ndim_b[i] == 1 else 1
+        # Offset and matrix count
+        offsetIn0 = 0
+        offsetIn1 = 0
+        nbMatrices = 1
+        for i in range(contiguousIdx):
+            nbMatrices *= self.attributes["out_dims"][0][i]
+        self.attributes["offset_in1"] = [0]
+        self.attributes["offset_in2"] = [0]
+        for stack in range(1, nbMatrices):
+            dim = contiguousIdx - 1
+            tmp_stack = stack
+            while tmp_stack % self.attributes["out_dims"][0][dim] == 0:
+                tmp_stack //= self.attributes["out_dims"][0][dim]
+                dim -= 1
+            offsetIn0 += stride_step0[dim]
+            offsetIn1 += stride_step1[dim]
+            self.attributes["offset_in1"].append(offsetIn0)
+            self.attributes["offset_in2"].append(offsetIn1)
         # Template for layer configuration file generation
         self.config_template = str(ROOT / "templates" / "configuration" / "elemwise_config.jinja")
......
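The exporter now performs, once per node at export time, the same broadcast analysis the old kernel redid on every call: pad the input shapes to the output rank, find the trailing block where both shapes agree, derive per-dimension steps, and unroll the stack loop into explicit offset lists. Here is a hedged C++ mirror of that Python logic, useful as a sanity check; the shapes are the same hypothetical A = (2, 1, 4), B = (3, 4), out = (2, 3, 4) example, for which it prints offset_in1 = {0, 0, 0, 1, 1, 1} and offset_in2 = {0, 1, 2, 0, 1, 2}.

#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical shapes: A = (2, 1, 4) and B = (3, 4), broadcast to (2, 3, 4).
    const std::vector<int> out_dims = {2, 3, 4};
    const std::vector<int> dims_a = {2, 1, 4};
    const std::vector<int> dims_b = {3, 4};
    const int nbdims_out = static_cast<int>(out_dims.size());

    // Left-pad both input shapes with 1s to the output rank.
    std::vector<int> ndim_a(nbdims_out, 1), ndim_b(nbdims_out, 1);
    for (size_t i = 0; i < dims_a.size(); ++i) ndim_a[nbdims_out - dims_a.size() + i] = dims_a[i];
    for (size_t i = 0; i < dims_b.size(); ++i) ndim_b[nbdims_out - dims_b.size() + i] = dims_b[i];

    // Trailing dimensions that match are contiguous in both inputs.
    int contiguousIdx = nbdims_out - 1;
    for (int i = nbdims_out - 1; i >= 0 && ndim_a[i] == ndim_b[i]; --i)
        contiguousIdx = i;

    // Per-dimension steps over the broadcast (non-contiguous) dimensions.
    std::vector<int> post0(contiguousIdx), post1(contiguousIdx);
    std::vector<int> step0(contiguousIdx), step1(contiguousIdx);
    if (contiguousIdx > 0) {
        post0[contiguousIdx - 1] = post1[contiguousIdx - 1] = 1;
        for (int i = contiguousIdx - 2; i >= 0; --i) {
            post0[i] = post0[i + 1] * ndim_a[i + 1];
            post1[i] = post1[i + 1] * ndim_b[i + 1];
        }
        for (int i = 0; i < contiguousIdx; ++i) {
            step0[i] = (ndim_a[i] == 1) ? 1 - post0[i] : 1;
            step1[i] = (ndim_b[i] == 1) ? 1 - post1[i] : 1;
        }
    }

    // Unroll the stack loop: one offset pair per contiguous block.
    int nbMatrices = 1;
    for (int i = 0; i < contiguousIdx; ++i) nbMatrices *= out_dims[i];
    int off0 = 0, off1 = 0;
    printf("stack 0: offset_in1=0 offset_in2=0\n");
    for (int stack = 1; stack < nbMatrices; ++stack) {
        int dim = contiguousIdx - 1, tmp = stack;
        while (tmp % out_dims[dim] == 0) { tmp /= out_dims[dim]; --dim; }
        off0 += step0[dim];
        off1 += step1[dim];
        printf("stack %d: offset_in1=%d offset_in2=%d\n", stack, off0, off1);
    }
    return 0;
}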
@@ -6,19 +6,13 @@
 {% include "./_def_io.jinja" %}
 {% include "./_meminfo.jinja" %}
 {# For layer configuration -#}
-#define {{ name|upper }}_NB_ELTS {{ in_dims[0]|join('*') }}
-#define {{ name|upper }}_NB_ELTS_B {{ in_dims[1]|join('*') }}
-int {{name|upper}}_OUTPUT_DIMS[] = { {{ out_dims[0]|join(", ") }} };
-int {{name|upper}}_INPUT_A_DIMS[] = { {{ in_dims[0]|join(", ") }} };
-int {{name|upper}}_INPUT_B_DIMS[] = { {{ in_dims[1]|join(", ") }} };
-#define {{name|upper}}_SIZE_DIM_IN_A {{in_dims[0]|length}}
-#define {{name|upper}}_SIZE_DIM_IN_B {{in_dims[1]|length}}
-#define {{name|upper}}_SIZE_DIM_OUT {{out_dims[0]|length}}
-#define {{ name|upper }}_OUT_SIZE {{out_size[0]}}
+#define {{ name|upper }}_NB_MAT {{ offset_in1|length }}
+#define {{ name|upper }}_INPUT1_CONT_SIZE {{ input1_cont_size }}
+#define {{ name|upper }}_INPUT2_CONT_SIZE {{ input2_cont_size }}
+#define {{ name|upper }}_OUTPUT_CONT_SIZE {{ output_cont_size }}
+#define {{name|upper}}_SIZE_DIM_OUT {{out_dims[0]|length}}
+constexpr int {{name|upper}}_OFFSET_IN1[] = { {{ offset_in1|join(", ") }} };
+constexpr int {{name|upper}}_OFFSET_IN2[] = { {{ offset_in2|join(", ") }} };
 #define {{ name|upper }}_ACTIVATION {{ activation }}
 #define {{ name|upper }}_ELEM_OP {{ elemwise_op }}
......
 {% filter indent(width=4, first=False) %}
 {% include "./_mem_offset.jinja" %}
-elemwise_forward<{{name|upper}}_NB_ELTS,
+elemwise_forward<{{name|upper}}_NB_MAT,
                  {{name|upper}}_ELEM_OP,
-                 {{name|upper}}_INPUT_A_DIMS,
-                 {{name|upper}}_INPUT_B_DIMS,
-                 {{name|upper}}_OUTPUT_DIMS,
-                 {{name|upper}}_SIZE_DIM_IN_A,
-                 {{name|upper}}_SIZE_DIM_IN_B,
-                 {{name|upper}}_SIZE_DIM_OUT,
-                 {{name|upper}}_OUT_SIZE,
+                 {{name|upper}}_INPUT1_CONT_SIZE,
+                 {{name|upper}}_INPUT2_CONT_SIZE,
+                 {{name|upper}}_OUTPUT_CONT_SIZE,
+                 {{name|upper}}_OFFSET_IN1,
+                 {{name|upper}}_OFFSET_IN2,
                  {{name|upper}}_ACTIVATION>
     ({{out_name[0]}}, {{name|upper}}_RESCALING, {{in_name[0]}}, {{in_name[1]}});
 {% include "./_save_outputs.jinja" %}
......
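Taken together, the two templates emit a constants-only configuration plus a fully templated call. For a hypothetical node named add0 with the shapes used above, the generated code could look roughly like this; the node name, the Linear activation token, and the ADD0_RESCALING definition are illustrative assumptions, since the real values come from the export attributes and the rescaling template.

// Hypothetical output of elemwise_config.jinja for a node "add0"
// with A = (2, 1, 4), B = (3, 4), out = (2, 3, 4):
#define ADD0_NB_MAT 6
#define ADD0_INPUT1_CONT_SIZE 4
#define ADD0_INPUT2_CONT_SIZE 4
#define ADD0_OUTPUT_CONT_SIZE 4
#define ADD0_SIZE_DIM_OUT 3
constexpr int ADD0_OFFSET_IN1[] = { 0, 0, 0, 1, 1, 1 };
constexpr int ADD0_OFFSET_IN2[] = { 0, 1, 2, 0, 1, 2 };
#define ADD0_ACTIVATION Linear   // assumed activation token
#define ADD0_ELEM_OP Add

// ...and the corresponding call emitted by the forward template
// (buffer names illustrative, ADD0_RESCALING defined elsewhere):
elemwise_forward<ADD0_NB_MAT,
                 ADD0_ELEM_OP,
                 ADD0_INPUT1_CONT_SIZE,
                 ADD0_INPUT2_CONT_SIZE,
                 ADD0_OUTPUT_CONT_SIZE,
                 ADD0_OFFSET_IN1,
                 ADD0_OFFSET_IN2,
                 ADD0_ACTIVATION>
    (add0_output, ADD0_RESCALING, add0_input_a, add0_input_b);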