diff --git a/aidge_export_cpp/kernels/concat.hpp b/aidge_export_cpp/kernels/concat.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2db8a0b039e31d44e27b38d37c9921ef38ca5a78 --- /dev/null +++ b/aidge_export_cpp/kernels/concat.hpp @@ -0,0 +1,22 @@ +#ifndef __AIDGE_EXPORT_CPP_KERNELS_CONCAT__ +#define __AIDGE_EXPORT_CPP_KERNELS_CONCAT__ + +template<typename T, unsigned int NB_INPUTS> +__attribute__((always_inline)) inline static +void concat_forward ( + const unsigned int axis, + const T* const * __restrict inputs, + const unsigned int* __restrict sizes, + T* __restrict output) +{ + unsigned int offset = 0; + + for (unsigned int n = 0; n < NB_INPUTS; ++n) { + for (unsigned int i = 0; i < sizes[n]; ++i) { + output[offset + i] = inputs[n][i]; + } + offset += sizes[n]; + } +} + +#endif // __AIDGE_EXPORT_CPP_KERNELS_CONCAT__ \ No newline at end of file diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index b7e54727459075af5494769e3047084e0f662c63..f49c501c9ea3f4bc38a35837d65e36bf8c394b5b 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -338,3 +338,19 @@ class BatchNorm2DCPP(ExportNodeCpp): str(ROOT / "kernels" / "rescaling.hpp") ] +@ExportLibCpp.register("Concat", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) +class Concat(ExportNodeCpp): + def __init__(self, node, mem_info): + super().__init__(node, mem_info) + + print(node.get_operator()) + print(dir(node.get_operator())) + self.attributes["nb_in"] = node.get_operator().nb_inputs() + self.attributes["axis"] = node.get_operator().attr.axis + + self.config_template = str(ROOT / "templates" / "configuration" / "concat.jinja") + self.forward_template = str(ROOT / "templates" / "forward_call" / "concat.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "concat.hpp"), + ] \ No newline at end of file diff --git a/aidge_export_cpp/templates/configuration/concat.jinja b/aidge_export_cpp/templates/configuration/concat.jinja new file mode 100644 index 0000000000000000000000000000000000000000..8aa63156a2d890bbfb6f0c7ddce700917ccef83b --- /dev/null +++ b/aidge_export_cpp/templates/configuration/concat.jinja @@ -0,0 +1,15 @@ +{#- For name header -#} +#ifndef {{ name|upper }}_LAYER_H +#define {{ name|upper }}_LAYER_H + +{% include "./_meminfo.jinja" %} + +// Attributes +#define {{ name|upper }}_NB_INPUTS {{ nb_in }} +#define {{ name|upper }}_AXIS {{ axis }} +{%- for i in range(nb_in) %} +#define {{ name|upper }}_INPUT_{{i}}_SIZE {{ in_chan[i] * in_height[i] * in_width[i] }} +{%- endfor %} +#define {{ name|upper }}_OUTPUT_SIZE {{ out_chan[0] * out_height[0] * out_width[0] }} + +#endif /* {{ name|upper }}_LAYER_H */ diff --git a/aidge_export_cpp/templates/kernel_forward/concat.jinja b/aidge_export_cpp/templates/kernel_forward/concat.jinja new file mode 100644 index 0000000000000000000000000000000000000000..46fe87e43c51c672b3d74bfbaabbb21aaac12ee7 --- /dev/null +++ b/aidge_export_cpp/templates/kernel_forward/concat.jinja @@ -0,0 +1,20 @@ +{% filter indent(width=4, first=False) %} +{% include "./_mem_offset.jinja" %} +float* {{ name|upper }}_INPUTS[] = { + {%- for i in range(nb_in) -%} + {{ in_name[i] }}{{ ", " if not loop.last else "" }} + {%- endfor -%} +}; + +unsigned int {{ name|upper }}_SIZES[] = { + {%- for i in range(nb_in) -%} + {{ name|upper }}_INPUT_{{i}}_SIZE{{ ", " if not loop.last else "" }} + {%- endfor -%} +}; + +aidge_concat<float, {{ nb_in }}> ( + {{name|upper}}_AXIS, + {{ name|upper }}_INPUTS, + {{ name|upper }}_SIZES, + {{ out_name[0] }}); + {% endfilter %}