From 6d0b593b705d66d9eb449949ce3fff7ae6821c9c Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 25 Feb 2025 15:46:42 +0000 Subject: [PATCH] Fix: do not transpose data in Tensor if current or new format is Default --- include/aidge/data/Tensor.hpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 785caaa0e..5df59becd 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -431,7 +431,7 @@ public: * @brief Get the data format enum. * @return constexpr DataFormat */ - constexpr DataFormat dataFormat() const noexcept { return mDataFormat; } + const DataFormat& dataFormat() const noexcept { return mDataFormat; } /** * @brief Set the DataType of the Tensor and converts data @@ -462,13 +462,13 @@ public: * data is copy-transposed. */ void setDataFormat(const DataFormat df, bool copyTrans = true) { - if (!copyTrans || df == dataFormat()) { + if (!copyTrans || df == dataFormat() || df == DataFormat::Default || dataFormat() == DataFormat::Default) { mDataFormat = df; return; } - - const auto transpose = getDataFormatTranspose(dataFormat(), df); - + + const auto transpose = getPermutationMapping(dataFormat(), df); + if (mImpl) { copyTranspose(*this, transpose); } else { @@ -476,7 +476,7 @@ public: for (std::size_t i = 0; i < dims().size(); ++i) { newDims.push_back(dims()[transpose[i]]); } - + std::vector<std::size_t> newStrides(dims().size(), 1); for (size_t i = 0; i < dims().size(); ++i) { for (size_t j = i + 1; j < dims().size(); ++j) { @@ -486,9 +486,10 @@ public: mDims = std::move(newDims); mStrides = std::move(newStrides); } - + mDataFormat = df; } + /** * @brief Get the Impl object * @return constexpr const std::shared_ptr<TensorImpl>& -- GitLab