From d2689b32cd2c188b953c5c9be26f3ce0d6966569 Mon Sep 17 00:00:00 2001
From: Wissam Boussella <wissam.boussella@cea.fr>
Date: Thu, 20 Feb 2025 10:44:00 +0100
Subject: [PATCH] Refactor setDataFormat method to improve data type handling
 and transformation logic

---
 include/aidge/data/Tensor.hpp | 45 +++++++++++++++++++++++++++--------
 1 file changed, 35 insertions(+), 10 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 7aa2ed52b..5e4a817bb 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -452,19 +452,44 @@ public:
     }
 
     /**
-     * @brief Set the DataFormat of the Tensor and transpose data, only
-     * if the Tensor has already been initialized and copyTrans is true.
-     * In this case, a transposition occurs only if both previous format and
-     * new format are different from DataFormat::Default.
-     * @param df New DataFormat
-     * @param copyTrans If true (default), when both previous format and new
-     *                  format are different from DataFormat::Default, previous
-     *                  data is copy-transposed.
+     * @brief Set the DataType of the Tensor and converts data
+     * if the Tensor has already been initialized and copyCast is true.
+     * @param dt DataType
+     * @param copyCast If true (default), previous data is copy-casted. Otherwise
+     * previous data is lost.
      */
     void setDataFormat(const DataFormat df, bool copyTrans = true) {
-        if (mImpl && copyTrans && (dataFormat() != df) && df != DataFormat::Default && dataFormat() != DataFormat::Default) {
-            copyTranspose(*this, getDataFormatTranspose(dataFormat(), df));
+        if (!copyTrans || df == dataFormat()) {
+            mDataFormat = df;
+            return;
+        }
+        // Skip transformation if both formats are Default or NNCH
+        if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) || df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW) {
+            mDataFormat = df;
+            return;
+        }
+    
+        const auto transpose = getDataFormatTranspose(dataFormat(), df);
+        
+        if (mImpl) {
+            copyTranspose(*this, transpose);
+        } else {
+            std::vector<DimSize_t> newDims;
+            newDims.reserve(nbDims());
+            for (std::size_t i = 0; i < nbDims(); ++i) {
+                newDims.push_back(dims()[transpose[i]]);
+            }
+            mDims = std::move(newDims);
+
+            std::vector<std::size_t> newStrides(nbDims(), 1);
+            for (size_t i = 0; i < nbDims(); ++i) {
+                for (size_t j = i + 1; j < nbDims(); ++j) {
+                    newStrides[i] *= newDims[j];
+                }
+            }
+            mStrides = std::move(newStrides);
         }
+    
         mDataFormat = df;
     }
 
-- 
GitLab