Skip to content
Snippets Groups Projects
Commit eb8bda50 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] coord conversion functions in Tensor

parent a2248696
No related branches found
No related tags found
1 merge request!9Fuse bn
Pipeline #32234 failed
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
...@@ -580,13 +580,14 @@ class Tensor : public Data, ...@@ -580,13 +580,14 @@ class Tensor : public Data,
* @param flatIdx 1D index of the value considering a flatten tensor. * @param flatIdx 1D index of the value considering a flatten tensor.
* @return std::vector<DimSize_t> * @return std::vector<DimSize_t>
*/ */
std::vector<std::size_t> getCoord(std::size_t flatIdx){ std::vector<std::size_t> getCoord(std::size_t flatIdx) const {
std::vector<std::size_t> coordIdx = {}; std::vector<std::size_t> coordIdx = std::vector<std::size_t>(mDims.size());
std::size_t idx = flatIdx; std::size_t idx = flatIdx;
for (std::size_t d: mDims){ for (std::size_t i = mDims.size() - 1; i > 0; --i){
coordIdx.push_back(idx % d); coordIdx[i] = (idx % mDims[i]);
idx/=d; idx/=mDims[i];
} }
coordIdx[0] = idx % mDims[0];
return coordIdx; return coordIdx;
} }
...@@ -596,16 +597,17 @@ class Tensor : public Data, ...@@ -596,16 +597,17 @@ class Tensor : public Data,
* @param coordIdx Coordinate to an element in the tensor * @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t * @return DimSize_t
*/ */
std::size_t getIdx(std::vector<std::size_t> coordIdx){ std::size_t getIdx(std::vector<std::size_t> coordIdx) const {
// std::size_t flatIdx = 0;
// std::size_t stride = 1;
std::size_t flatIdx = 0; std::size_t flatIdx = 0;
std::size_t stride = 1;
assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions"); assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions");
for(std::size_t i=0; i< mDims.size(); ++i){ std::size_t i = 0;
for(; i < mDims.size() - 1; ++i){
assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor"); assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor");
flatIdx += (coordIdx[i] * stride); flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1];
stride *= mDims[i];
} }
return flatIdx; return flatIdx + coordIdx[i];
} }
private: private:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment