Skip to content
Snippets Groups Projects
Commit b7b395f3 authored by Wissam Boussella's avatar Wissam Boussella
Browse files

NB_DIMS is now a template parameter

parent ac47d3e6
No related branches found
No related tags found
2 merge requests: !39 "Update 0.2.1 -> 0.3.0", !30 "Add transpose function for 4D tensors and related templates"
Pipeline #67214 passed
......@@ -23,60 +23,59 @@
* Based on Tensor::copyTranspose from aidge.aidge_core
*
* @tparam T Data type of the tensor elements.
* @tparam NB_DIMS Number of dimensions of the input tensor.
* @param[in] inputs Pointer to the input tensor data stored in contiguous memory.
* @param[in] in_dims Array containing the size of each dimension of the input tensor.
* @param[in] nb_dims Number of dimensions of the input tensor.
* @param[in] permute Array of unsigned integers specifying the desired permutation
* of dimensions. Each value should be in the range [0, nb_dims-1],
* of dimensions. Each value should be in the range [0, NB_DIMS-1],
* defining the new order of dimensions for the output tensor.
* @param[in] total_size Total number of elements in the input/output tensor.
* @param[out] outputs Pointer to the pre-allocated memory for the transposed tensor.
* Ensure this memory is appropriately sized to hold the transposed data.
*/
template <typename T>
template <typename T,unsigned int NB_DIMS>
void transpose_ND_forward(const T *__restrict inputs,
const unsigned int *in_dims,
const unsigned int nb_dims,
const unsigned int *permute,
const unsigned int total_size,
T *__restrict outputs)
{
// Compute strides for input tensor
unsigned int in_strides[nb_dims];
in_strides[nb_dims - 1] = 1;
for (int i = nb_dims - 2; i >= 0; --i)
unsigned int in_strides[NB_DIMS];
in_strides[NB_DIMS - 1] = 1;
for (int i = NB_DIMS - 2; i >= 0; --i)
{
in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
}
// Compute dimensions and strides for output tensor
unsigned int out_dims[nb_dims];
unsigned int out_strides[nb_dims];
out_strides[nb_dims - 1] = 1;
for (unsigned int i = 0; i < nb_dims; ++i)
unsigned int out_dims[NB_DIMS];
unsigned int out_strides[NB_DIMS];
out_strides[NB_DIMS - 1] = 1;
for (unsigned int i = 0; i < NB_DIMS; ++i)
{
out_dims[i] = in_dims[permute[i]];
}
for (int i = nb_dims - 2; i >= 0; --i)
for (int i = NB_DIMS - 2; i >= 0; --i)
{
out_strides[i] = out_strides[i + 1] * out_dims[i + 1];
}
unsigned int current_idx[nb_dims];
unsigned int current_idx[NB_DIMS];
// Iterate over all elements in the input tensor
for (unsigned int idx = 0; idx < total_size; ++idx)
{
unsigned int remaining = idx;
for (unsigned int i = 0; i < nb_dims; ++i)
for (unsigned int i = 0; i < NB_DIMS; ++i)
{
current_idx[i] = remaining / in_strides[i];
remaining = remaining % in_strides[i];
}
unsigned int output_index = 0;
for (unsigned int i = 0; i < nb_dims; ++i)
for (unsigned int i = 0; i < NB_DIMS; ++i)
{
output_index += current_idx[permute[i]] * out_strides[i];
}
......
transpose_ND_forward<{{in_cdtype[0]}}>({{in_name[0]}},{{name|upper}}_DIMS,{{name|upper}}_NB_DIMS,{{name|upper}}_PERMUTE,{{ out_name[0]|upper }}_SIZE,{{out_name[0]}});
\ No newline at end of file
transpose_ND_forward<{{in_cdtype[0]}},{{name|upper}}_NB_DIMS>({{in_name[0]}},{{name|upper}}_DIMS,{{name|upper}}_PERMUTE,{{ out_name[0]|upper }}_SIZE,{{out_name[0]}});
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment