From eb8bda502309f66974b15368a200f61e9cb043c8 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 29 Sep 2023 09:21:16 +0000
Subject: [PATCH] [Fix] coord conversion functions in Tensor

---
 include/aidge/data/Tensor.hpp | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 0a67d73a9..7422a52eb 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -580,13 +580,14 @@ class Tensor : public Data,
      * @param flatIdx 1D index of the value considering a flatten tensor.
      * @return std::vector<DimSize_t>
      */
-    std::vector<std::size_t> getCoord(std::size_t flatIdx){
-        std::vector<std::size_t> coordIdx = {};
+    std::vector<std::size_t> getCoord(std::size_t flatIdx) const {
+        std::vector<std::size_t> coordIdx = std::vector<std::size_t>(mDims.size());
         std::size_t idx = flatIdx;
-        for (std::size_t d: mDims){
-            coordIdx.push_back(idx % d);
-            idx/=d;
+        for (std::size_t i = mDims.size() - 1; i > 0; --i){
+            coordIdx[i] = (idx % mDims[i]);
+            idx/=mDims[i];
         }
+        coordIdx[0] = idx % mDims[0];
         return coordIdx;
     }
 
@@ -596,16 +597,17 @@ class Tensor : public Data,
      * @param coordIdx Coordinate to an element in the tensor
      * @return DimSize_t
      */
-    std::size_t getIdx(std::vector<std::size_t> coordIdx){
+    std::size_t getIdx(std::vector<std::size_t> coordIdx) const {
+        // std::size_t flatIdx = 0;
+        // std::size_t stride = 1;
         std::size_t flatIdx = 0;
-        std::size_t stride = 1;
         assert(coordIdx.size() == mDims.size() && "Coordinates does not match number of dimensions");
-        for(std::size_t i=0; i< mDims.size(); ++i){
+        std::size_t i = 0;
+        for(; i < mDims.size() - 1; ++i){
             assert(coordIdx[i] < mDims[i] && "Coordinates dimensions does not fit the dimensions of the tensor");
-            flatIdx += (coordIdx[i] * stride);
-            stride *= mDims[i];
+            flatIdx = (flatIdx + coordIdx[i]) * mDims[i + 1];
         }
-        return flatIdx;
+        return flatIdx + coordIdx[i];
     }
 
 private:
-- 
GitLab