diff --git a/aidge_export_cpp/kernels/transpose.hpp b/aidge_export_cpp/kernels/transpose.hpp index 082d738df6e18b668b2d533c3d04e814c57bd1e7..31c9e27869c5e2fde701f6700fd4964ea4cefd29 100644 --- a/aidge_export_cpp/kernels/transpose.hpp +++ b/aidge_export_cpp/kernels/transpose.hpp @@ -23,60 +23,59 @@ * Based on Tensor::copyTranspose from aidge.aidge_core * * @tparam T Data type of the tensor elements. + * @tparam NB_DIMS Number of dimensions of the input tensor. * @param[in] inputs Pointer to the input tensor data stored in contiguous memory. * @param[in] in_dims Array containing the size of each dimension of the input tensor. - * @param[in] nb_dims Number of dimensions of the input tensor. * @param[in] permute Array of unsigned integers specifying the desired permutation - * of dimensions. Each value should be in the range [0, nb_dims-1], + * of dimensions. Each value should be in the range [0, NB_DIMS-1], * defining the new order of dimensions for the output tensor. * @param[in] total_size Total number of elements in the input/output tensor. * @param[out] outputs Pointer to the pre-allocated memory for the transposed tensor. * Ensure this memory is appropriately sized to hold the transposed data. */ -template <typename T> +template <typename T,unsigned int NB_DIMS> void transpose_ND_forward(const T *__restrict inputs, const unsigned int *in_dims, - const unsigned int nb_dims, const unsigned int *permute, const unsigned int total_size, T *__restrict outputs) { // Compute strides for input tensor - unsigned int in_strides[nb_dims]; - in_strides[nb_dims - 1] = 1; - for (int i = nb_dims - 2; i >= 0; --i) + unsigned int in_strides[NB_DIMS]; + in_strides[NB_DIMS - 1] = 1; + for (int i = NB_DIMS - 2; i >= 0; --i) { in_strides[i] = in_strides[i + 1] * in_dims[i + 1]; } // Compute dimensions and strides for output tensor - unsigned int out_dims[nb_dims]; - unsigned int out_strides[nb_dims]; - out_strides[nb_dims - 1] = 1; - for (unsigned int i = 0; i < nb_dims; ++i) + unsigned int out_dims[NB_DIMS]; + unsigned int out_strides[NB_DIMS]; + out_strides[NB_DIMS - 1] = 1; + for (unsigned int i = 0; i < NB_DIMS; ++i) { out_dims[i] = in_dims[permute[i]]; } - for (int i = nb_dims - 2; i >= 0; --i) + for (int i = NB_DIMS - 2; i >= 0; --i) { out_strides[i] = out_strides[i + 1] * out_dims[i + 1]; } - unsigned int current_idx[nb_dims]; + unsigned int current_idx[NB_DIMS]; // Iterate over all elements in the input tensor for (unsigned int idx = 0; idx < total_size; ++idx) { unsigned int remaining = idx; - for (unsigned int i = 0; i < nb_dims; ++i) + for (unsigned int i = 0; i < NB_DIMS; ++i) { current_idx[i] = remaining / in_strides[i]; remaining = remaining % in_strides[i]; } unsigned int output_index = 0; - for (unsigned int i = 0; i < nb_dims; ++i) + for (unsigned int i = 0; i < NB_DIMS; ++i) { output_index += current_idx[permute[i]] * out_strides[i]; } diff --git a/aidge_export_cpp/templates/kernel_forward/transpose_ND_forward.jinja b/aidge_export_cpp/templates/kernel_forward/transpose_ND_forward.jinja index 8f39fbce7f4e5ceba36ad39929e527aaa7c927af..25af5bd9a3cdab4c91d5f2f09dae9144348729db 100644 --- a/aidge_export_cpp/templates/kernel_forward/transpose_ND_forward.jinja +++ b/aidge_export_cpp/templates/kernel_forward/transpose_ND_forward.jinja @@ -1 +1 @@ -transpose_ND_forward<{{in_cdtype[0]}}>({{in_name[0]}},{{name|upper}}_DIMS,{{name|upper}}_NB_DIMS,{{name|upper}}_PERMUTE,{{ out_name[0]|upper }}_SIZE,{{out_name[0]}}); \ No newline at end of file +transpose_ND_forward<{{in_cdtype[0]}},{{name|upper}}_NB_DIMS>({{in_name[0]}},{{name|upper}}_DIMS,{{name|upper}}_PERMUTE,{{ out_name[0]|upper }}_SIZE,{{out_name[0]}}); \ No newline at end of file