From 6d0b593b705d66d9eb449949ce3fff7ae6821c9c Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 25 Feb 2025 15:46:42 +0000
Subject: [PATCH] Fix: do not transpose data in Tensor if current or new format
 is Default

---
 include/aidge/data/Tensor.hpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 785caaa0e..5df59becd 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -431,7 +431,7 @@ public:
      * @brief Get the data format enum.
      * @return constexpr DataFormat
      */
-    constexpr DataFormat dataFormat() const noexcept { return mDataFormat; }
+    const DataFormat& dataFormat() const noexcept { return mDataFormat; }
 
     /**
      * @brief Set the DataType of the Tensor and converts data
@@ -462,13 +462,13 @@ public:
      *                  data is copy-transposed.
      */
     void setDataFormat(const DataFormat df, bool copyTrans = true) {
-        if (!copyTrans || df == dataFormat()) {
+        if (!copyTrans || df == dataFormat() || df == DataFormat::Default || dataFormat() == DataFormat::Default) {
             mDataFormat = df;
             return;
         }
-    
-        const auto transpose = getDataFormatTranspose(dataFormat(), df);
-        
+
+        const auto transpose = getPermutationMapping(dataFormat(), df);
+
         if (mImpl) {
             copyTranspose(*this, transpose);
         } else {
@@ -476,7 +476,7 @@ public:
             for (std::size_t i = 0; i < dims().size(); ++i) {
                 newDims.push_back(dims()[transpose[i]]);
             }
-    
+
             std::vector<std::size_t> newStrides(dims().size(), 1);
             for (size_t i = 0; i < dims().size(); ++i) {
                 for (size_t j = i + 1; j < dims().size(); ++j) {
@@ -486,9 +486,10 @@ public:
             mDims = std::move(newDims);
             mStrides = std::move(newStrides);
         }
-    
+
         mDataFormat = df;
     }
+
     /**
      * @brief Get the Impl object
      * @return constexpr const std::shared_ptr<TensorImpl>&
-- 
GitLab