diff --git a/aidge_export_cpp/kernels/concat.hpp b/aidge_export_cpp/kernels/concat.hpp
index 2db8a0b039e31d44e27b38d37c9921ef38ca5a78..dde8c4fc3a9ce9eea5d4ae4cfad35c078f60450d 100644
--- a/aidge_export_cpp/kernels/concat.hpp
+++ b/aidge_export_cpp/kernels/concat.hpp
@@ -1,22 +1,39 @@
 #ifndef __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
 #define __AIDGE_EXPORT_CPP_KERNELS_CONCAT__

-template<typename T, unsigned int NB_INPUTS>
+template<int AXIS_SIZE_POST,
+         int AXIS_SIZE_PRE,
+         unsigned int NB_INPUTS,
+         typename T>
 __attribute__((always_inline)) inline static
 void concat_forward (
-    const unsigned int axis,
     const T* const * __restrict inputs,
     const unsigned int* __restrict sizes,
     T* __restrict output)
 {
-    unsigned int offset = 0;
+    unsigned int total_concat_axis_size = 0;
+    for (unsigned int n = 0; n < NB_INPUTS; ++n)
+        total_concat_axis_size += sizes[n];

-    for (unsigned int n = 0; n < NB_INPUTS; ++n) {
-        for (unsigned int i = 0; i < sizes[n]; ++i) {
-            output[offset + i] = inputs[n][i];
+    for (int i = 0; i < AXIS_SIZE_PRE; ++i) {
+        // Loop over post-axis (e.g., dims after axis 1)
+        for (int j = 0; j < AXIS_SIZE_POST; ++j) {
+            unsigned int axis_offset = 0;
+
+            // Loop over each input tensor
+            for (unsigned int n = 0; n < NB_INPUTS; ++n) {
+                for (unsigned int k = 0; k < sizes[n]; ++k) {
+                    const int input_idx = i * sizes[n] * AXIS_SIZE_POST + k * AXIS_SIZE_POST + j;
+
+                    output[i * total_concat_axis_size * AXIS_SIZE_POST + (axis_offset + k) * AXIS_SIZE_POST + j] =
+                        inputs[n][input_idx];
+                }
+
+                axis_offset += sizes[n]; // move along axis in output
+            }
         }
-        offset += sizes[n];
     }
 }

 #endif // __AIDGE_EXPORT_CPP_KERNELS_CONCAT__
\ No newline at end of file
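Reviewer sketch (not part of the patch): a minimal standalone driver for the new kernel, assuming the header above; buffer names and sizes are illustrative. It concatenates a [2, 2, 2] and a [2, 3, 2] tensor along axis 1, so AXIS_SIZE_PRE == 2 and AXIS_SIZE_POST == 2:

    // Hypothetical driver, assuming aidge_export_cpp/kernels/concat.hpp above.
    #include <cstdio>
    #include "concat.hpp"

    int main() {
        // Input A: [2, 2, 2], input B: [2, 3, 2], concatenated along axis 1
        // into an output of shape [2, 5, 2] (row-major).
        const float a[2 * 2 * 2] = {0, 1, 2, 3, 4, 5, 6, 7};
        const float b[2 * 3 * 2] = {10, 11, 12, 13, 14, 15,
                                    16, 17, 18, 19, 20, 21};
        const float* inputs[2] = {a, b};
        const unsigned int sizes[2] = {2, 3};  // per-input extent along the concat axis
        float out[2 * 5 * 2];

        // Template order: <AXIS_SIZE_POST, AXIS_SIZE_PRE, NB_INPUTS, T>
        concat_forward<2, 2, 2, float>(inputs, sizes, out);

        // Expected: 0 1 2 3 10 11 12 13 14 15 | 4 5 6 7 16 17 18 19 20 21
        for (unsigned int i = 0; i < 2 * 5 * 2; ++i)
            std::printf("%g ", out[i]);
        std::printf("\n");
        return 0;
    }

Making the pre/post extents template parameters (instead of the old runtime axis argument) gives the compiler compile-time trip counts, at the cost of one kernel instantiation per distinct shape.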
diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 7c22cdb7af01392e0e3deb05bf4aecc16565e6a9..26ca62155401707573d9625ad91a9b63cb1b4d2b 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -337,28 +337,27 @@ class TransposeCPP(ExportNodeCpp):
 class SoftmaxCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
-        self.attributes["axis"] = node.get_operator().attr.axis
-
         assert self.node.get_nb_inputs() == 1, (
             f"export softmax: nb_inputs == {self.node.get_nb_inputs()} not implemented"
         )

         tensor = self.operator.get_input(0)
         nbDims = len(tensor.dims())
+        axis = node.get_operator().attr.axis if node.get_operator().attr.axis >= 0 else node.get_operator().attr.axis + nbDims

-        assert self.attributes["axis"] < nbDims, (
+        assert axis < nbDims, (
             f"export softmax: attribute axis == {node.get_operator().attr.axis} should be less than {nbDims}"
         )

         postAxisElems = 1
-        for i in range(self.attributes["axis"] + 1, nbDims):
+        for i in range(axis + 1, nbDims):
             postAxisElems *= tensor.dims()[i]

         preAxisElems = 1
-        for i in range(self.attributes["axis"]):
+        for i in range(axis):
             preAxisElems *= tensor.dims()[i]

-        self.attributes["axis_size"] = tensor.dims()[self.attributes["axis"]]
+        self.attributes["axis_size"] = tensor.dims()[axis]
         self.attributes["axis_size_post"] = postAxisElems
         self.attributes["axis_size_pre"] = preAxisElems

@@ -395,7 +394,50 @@ class BatchNorm2DCPP(ExportNodeCpp):
 class Concat(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
-        self.attributes["axis"] = node.get_operator().attr.axis
+        assert self.node.get_nb_inputs() >= 1, (
+            f"export concat: nb_inputs == {self.node.get_nb_inputs()} not implemented"
+        )
+
+        inputIndex = 0
+
+        tensor = self.operator.get_input(0)
+        for idx, _ in enumerate(self.node.inputs()):
+            if self.operator.get_input(idx) is not None:
+                tensor = self.operator.get_input(idx)
+                nbDims = len(tensor.dims())
+                axis = node.get_operator().attr.axis if node.get_operator().attr.axis >= 0 else node.get_operator().attr.axis + nbDims
+
+                assert axis < nbDims, (
+                    f"export concat: attribute axis == {axis} should be less than {nbDims}"
+                )
+
+                postAxisElems = 1
+                for i in range(axis + 1, nbDims):
+                    postAxisElems *= tensor.dims()[i]
+
+                preAxisElems = 1
+                for i in range(axis):
+                    preAxisElems *= tensor.dims()[i]
+
+                if inputIndex == 0:
+                    self.attributes["axis_size_post"] = postAxisElems
+                    self.attributes["axis_size_pre"] = preAxisElems
+
+                    self.attributes["axis_size"] = [None] * self.attributes["nb_in"]
+                else:
+                    assert self.attributes["axis_size_post"] == postAxisElems, (
+                        f"export concat: axis_size_post {self.attributes['axis_size_post']} != {postAxisElems}"
+                    )
+                    assert self.attributes["axis_size_pre"] == preAxisElems, (
+                        f"export concat: axis_size_pre {self.attributes['axis_size_pre']} != {preAxisElems}"
+                    )
+
+                self.attributes["axis_size"][idx] = tensor.dims()[axis]
+            else:
+                assert False, (
+                    f"export concat: input {idx} is None, not implemented")
+
+            inputIndex += 1
         self.config_template = str(ROOT / "templates" / "configuration" / "concat_config.jinja")
         self.forward_template = str(ROOT / "templates" / "kernel_forward" / "concat_forward.jinja")
diff --git a/aidge_export_cpp/templates/configuration/concat_config.jinja b/aidge_export_cpp/templates/configuration/concat_config.jinja
index 1a6637e9094c12fc04d47da4bddcee160f3c7a56..ea8246db9a315a371e0cacea5d45d07fa2b8f7e8 100644
--- a/aidge_export_cpp/templates/configuration/concat_config.jinja
+++ b/aidge_export_cpp/templates/configuration/concat_config.jinja
@@ -9,8 +9,10 @@
 #define {{ name|upper }}_NB_INPUTS {{ nb_in }}
 #define {{ name|upper }}_AXIS {{ axis }}
 {%- for i in range(nb_in) %}
-#define {{ name|upper }}_INPUT_{{i}}_SIZE {{ in_dims[i]|join('*') }}
+#define {{ name|upper }}_INPUT_{{i}}_SIZE {{ axis_size[i] }}
 {%- endfor %}
-#define {{ name|upper }}_OUTPUT_SIZE {{ out_dims[0]|join('*')}}
+
+#define {{ name|upper }}_AXIS_SIZE_POST {{ axis_size_post }}
+#define {{ name|upper }}_AXIS_SIZE_PRE {{ axis_size_pre }}

 #endif /* {{ name|upper }}_LAYER_H */
diff --git a/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja b/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja
index a2f48e9b5aa563924a346eabfaec67b0c30ef38f..7a77e904db6c18f338f93099f4f117c9285bf6fc 100644
--- a/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja
+++ b/aidge_export_cpp/templates/kernel_forward/concat_forward.jinja
@@ -12,8 +12,10 @@ unsigned int {{ name|upper }}_SIZES[] = {
 {%- endfor -%}
 };

-concat_forward<float, {{ nb_in }}> (
-    {{name|upper}}_AXIS,
+concat_forward<{{ name|upper }}_AXIS_SIZE_POST,
+               {{ name|upper }}_AXIS_SIZE_PRE,
+               {{ nb_in }},
+               float> (
     {{ name|upper }}_INPUTS,
     {{ name|upper }}_SIZES,
     {{ out_name[0] }});
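For reference, here is roughly what the two updated templates generate for a two-input Concat named "concat" on [1, 5, 7] inputs along axis 1 (a sketch under assumed attribute values; the INPUTS/SIZES declarations come from the unchanged top of concat_forward.jinja, and the buffer names are made up):

    // Hypothetical expansion, assuming nb_in == 2, axis == 1 and two
    // [1, 5, 7] float inputs. in0/in1/concat_output are illustrative names;
    // the real export uses the graph's tensor names.
    #include "concat.hpp"

    #define CONCAT_NB_INPUTS 2
    #define CONCAT_AXIS 1
    #define CONCAT_INPUT_0_SIZE 5   // axis_size[0]: dim 1 of input 0
    #define CONCAT_INPUT_1_SIZE 5   // axis_size[1]: dim 1 of input 1

    #define CONCAT_AXIS_SIZE_POST 7 // product of dims after the axis
    #define CONCAT_AXIS_SIZE_PRE 1  // product of dims before the axis

    static float in0[1 * 5 * 7], in1[1 * 5 * 7], concat_output[1 * 10 * 7];

    void concat_layer_forward() {
        const float* CONCAT_INPUTS[] = {in0, in1};
        unsigned int CONCAT_SIZES[] = {CONCAT_INPUT_0_SIZE, CONCAT_INPUT_1_SIZE};

        concat_forward<CONCAT_AXIS_SIZE_POST,
                       CONCAT_AXIS_SIZE_PRE,
                       CONCAT_NB_INPUTS,
                       float> (
            CONCAT_INPUTS,
            CONCAT_SIZES,
            concat_output);
    }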
self.unit_test_export(model, "SoftmaxAxisNegative", [[1, 10, 3, 7]]) + def test_export_softmax_axis_0(self): print("SoftmaxAxis0") model = aidge_core.sequential([ @@ -376,6 +384,24 @@ class test_operator_export(unittest.TestCase): self.unit_test_export(model, "Concat", [[1, 5, 7]]) + def test_export_concat_axis_2(self): + print("ConcatAxis2") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Concat(nb_inputs=2, axis=2, name="concat") + ]) + + self.unit_test_export(model, "ConcatAxis2", [[1, 5, 7]]) + + def test_export_concat_axis_negative(self): + print("ConcatAxisNegative") + model = aidge_core.sequential([ + aidge_core.Producer([1, 5, 7], name="producer"), + aidge_core.Concat(nb_inputs=2, axis=-2, name="concat") + ]) + + self.unit_test_export(model, "ConcatAxisNegative", [[1, 5, 7]]) + def test_export_conv2D(self): print("Conv2D") model = aidge_core.sequential([