Skip to content
Snippets Groups Projects
Commit eb8bda50 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] coord conversion functions in Tensor

parent a2248696
No related branches found
No related tags found
1 merge request!9Fuse bn
Pipeline #32234 failed
......@@ -580,13 +580,14 @@ class Tensor : public Data,
* @param flatIdx 1D index of the value considering a flatten tensor.
* @return std::vector<DimSize_t>
*/
std::vector<std::size_t> getCoord(std::size_t flatIdx){
std::vector<std::size_t> coordIdx = {};
std::vector<std::size_t> getCoord(std::size_t flatIdx) const {
std::vector<std::size_t> coordIdx = std::vector<std::size_t>(mDims.size());
std::size_t idx = flatIdx;
for (std::size_t d: mDims){
coordIdx.push_back(idx % d);
idx/=d;
for (std::size_t i = mDims.size() - 1; i > 0; --i){
coordIdx[i] = (idx % mDims[i]);
idx/=mDims[i];
}
coordIdx[0] = idx % mDims[0];
return coordIdx;
}
......@@ -596,16 +597,17 @@ class Tensor : public Data,
* @param coordIdx Coordinate to an element in the tensor
* @return DimSize_t
*/
std::size_t getIdx(std::vector<std::size_t> coordIdx){
std::size_t getIdx(std::vector<std::size_t> coordIdx) const {
// std::size_t flatIdx = 0;
// std::size_t stride = 1;
std::size_t flatIdx = 0;
std::size_t stride = 1;
assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions");
for(std::size_t i=0; i< mDims.size(); ++i){
std::size_t i = 0;
for(; i < mDims.size() - 1; ++i){
assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor");
flatIdx += (coordIdx[i] * stride);
stride *= mDims[i];
flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1];
}
return flatIdx;
return flatIdx + coordIdx[i];
}
private:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment