From 7f5184ff5da19d72cd165b32217c64bc061609ac Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 23 Jan 2024 16:01:44 +0100 Subject: [PATCH] fix transpose operator --- .../TransposeImpl_forward_kernels.hpp | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/include/aidge/backend/cpu/operator/TransposeImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/TransposeImpl_forward_kernels.hpp index 307b6d99..9fd5e5b5 100644 --- a/include/aidge/backend/cpu/operator/TransposeImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/TransposeImpl_forward_kernels.hpp @@ -33,6 +33,14 @@ void TransposeImpl_cpu_forward_kernel( const typename Transpose_Op<DIM>::Attrs& totalElements *= dimSize; } + std::vector<std::size_t> outStrides(DIM, 1); + for (size_t i = 0; i < DIM; ++i) { + for (size_t j = i+1; j < DIM; ++j) + { + outStrides[i] *= outputDims[j]; + } + } + std::vector<size_t> indices(outputDims.size(), 0); for (size_t i = 0; i < totalElements; ++i) { size_t idx = 0; @@ -42,20 +50,15 @@ void TransposeImpl_cpu_forward_kernel( const typename Transpose_Op<DIM>::Attrs& permutedIndices[j] = indices[std::get<0>(attrs)[j]]; } - // Compute the position of the next element to copy from input - for (size_t j = 0; j < DIM; ++j) { - size_t currsize = 1; - for(size_t k=j+1; k< DIM; ++k) - currsize*= inputDims[k]; - idx += permutedIndices[j] * currsize; + for (int j = DIM -1; j >=0; --j) { + idx += permutedIndices[j] * outStrides[j]; } - // Copy the value in output - output[i] = input[idx]; + output[idx] = input[i]; // Update indices for the next iteration for (int j = DIM - 1; j >= 0; --j) { - if (indices[j] < outputDims[j] - 1) { + if (indices[j] < inputDims[j] - 1) { indices[j]++; break; } else { -- GitLab