Skip to content
Snippets Groups Projects
Commit 7f5184ff authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix transpose operator

parent f4c43da9
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!20Vit operators
......@@ -33,6 +33,14 @@ void TransposeImpl_cpu_forward_kernel( const typename Transpose_Op<DIM>::Attrs&
totalElements *= dimSize;
}
std::vector<std::size_t> outStrides(DIM, 1);
for (size_t i = 0; i < DIM; ++i) {
for (size_t j = i+1; j < DIM; ++j)
{
outStrides[i] *= outputDims[j];
}
}
std::vector<size_t> indices(outputDims.size(), 0);
for (size_t i = 0; i < totalElements; ++i) {
size_t idx = 0;
......@@ -42,20 +50,15 @@ void TransposeImpl_cpu_forward_kernel( const typename Transpose_Op<DIM>::Attrs&
permutedIndices[j] = indices[std::get<0>(attrs)[j]];
}
// Compute the position of the next element to copy from input
for (size_t j = 0; j < DIM; ++j) {
size_t currsize = 1;
for(size_t k=j+1; k< DIM; ++k)
currsize*= inputDims[k];
idx += permutedIndices[j] * currsize;
for (int j = DIM -1; j >=0; --j) {
idx += permutedIndices[j] * outStrides[j];
}
// Copy the value in output
output[i] = input[idx];
output[idx] = input[i];
// Update indices for the next iteration
for (int j = DIM - 1; j >= 0; --j) {
if (indices[j] < outputDims[j] - 1) {
if (indices[j] < inputDims[j] - 1) {
indices[j]++;
break;
} else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment