Skip to content
Snippets Groups Projects
Commit c7ce1826 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

More static concat, works with any type

parent 0be40067
No related branches found
No related tags found
2 merge requests!710.4.0,!59Continuous improvement of export_cpp
Pipeline #77457 passed with warnings
......@@ -3,18 +3,15 @@
template<int AXIS_SIZE_POST,
int AXIS_SIZE_PRE,
const int AXIS_SIZE[],
int TOTAL_AXIS_SIZE,
unsigned int NB_INPUTS,
typename T>
__attribute__((always_inline)) inline static
void concat_forward (
const T* const * __restrict inputs,
const unsigned int* __restrict sizes,
T* __restrict output)
{
unsigned int total_concat_axis_size = 0;
for (unsigned int n = 0; n < NB_INPUTS; ++n)
total_concat_axis_size += sizes[n];
for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
// Loop over post-axis (e.g., dims after axis 1)
for (int j = 0; j < AXIS_SIZE_POST; ++j) {
......@@ -22,18 +19,17 @@ void concat_forward (
// Loop over each input tensor
for (unsigned int n = 0; n < NB_INPUTS; ++n) {
for (unsigned int k = 0; k < sizes[n]; ++k) {
const int input_idx = i * sizes[n] * AXIS_SIZE_POST + k * AXIS_SIZE_POST + j;
for (unsigned int k = 0; k < AXIS_SIZE[n]; ++k) {
const int input_idx = i * AXIS_SIZE[n] * AXIS_SIZE_POST + k * AXIS_SIZE_POST + j;
output[i * total_concat_axis_size * AXIS_SIZE_POST + (axis_offset + k) * AXIS_SIZE_POST + j] =
output[i * TOTAL_AXIS_SIZE * AXIS_SIZE_POST + (axis_offset + k) * AXIS_SIZE_POST + j] =
inputs[n][input_idx];
}
axis_offset += sizes[n]; // move along axis in output
axis_offset += AXIS_SIZE[n]; // move along axis in output
}
}
}
}
#endif // __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
\ No newline at end of file
......@@ -8,9 +8,8 @@
// Attributes
#define {{ name|upper }}_NB_INPUTS {{ nb_in }}
#define {{ name|upper }}_AXIS {{ axis }}
{%- for i in range(nb_in) %}
#define {{ name|upper }}_INPUT_{{i}}_SIZE {{ axis_size[i] }}
{%- endfor %}
constexpr int {{name|upper}}_AXIS_SIZE[] = { {{ axis_size|join(", ") }} };
#define {{name|upper}}_TOTAL_AXIS_SIZE ({{ axis_size|join('+') }})
#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}
......
{% filter indent(width=4, first=False) %}
{% include "./_mem_offset.jinja" %}
const float* {{ name|upper }}_INPUTS[] = {
{%- for i in range(nb_in) -%}
{{ in_name[i] }}{{ ", " if not loop.last else "" }}
{%- endfor -%}
};
unsigned int {{ name|upper }}_SIZES[] = {
{%- for i in range(nb_in) -%}
{{ name|upper }}_INPUT_{{i}}_SIZE{{ ", " if not loop.last else "" }}
{%- endfor -%}
};
const {{ out_cdtype[0] }}* {{ name }}_inputs[] = { {{ in_name|join(", ") }} };
concat_forward<{{ name|upper }}_AXIS_SIZE_POST,
{{ name|upper }}_AXIS_SIZE_PRE,
{{ nb_in }}, float>
({{ name|upper }}_INPUTS,
{{ name|upper }}_SIZES,
{{ name|upper }}_AXIS_SIZE,
{{ name|upper }}_TOTAL_AXIS_SIZE,
{{ nb_in }}>
({{ name }}_inputs,
{{ out_name[0] }});
{%- endfilter %}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment