diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 785caaa0e8959ba34d438913a4c0e5bad3df0f86..5df59becdc41f12768935544a42aac24ffb3a333 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -431,7 +431,7 @@ public: * @brief Get the data format enum. * @return constexpr DataFormat */ - constexpr DataFormat dataFormat() const noexcept { return mDataFormat; } + const DataFormat& dataFormat() const noexcept { return mDataFormat; } /** * @brief Set the DataType of the Tensor and converts data @@ -462,13 +462,13 @@ public: * data is copy-transposed. */ void setDataFormat(const DataFormat df, bool copyTrans = true) { - if (!copyTrans || df == dataFormat()) { + if (!copyTrans || df == dataFormat() || df == DataFormat::Default || dataFormat() == DataFormat::Default) { mDataFormat = df; return; } - - const auto transpose = getDataFormatTranspose(dataFormat(), df); - + + const auto transpose = getPermutationMapping(dataFormat(), df); + if (mImpl) { copyTranspose(*this, transpose); } else { @@ -476,7 +476,7 @@ public: for (std::size_t i = 0; i < dims().size(); ++i) { newDims.push_back(dims()[transpose[i]]); } - + std::vector<std::size_t> newStrides(dims().size(), 1); for (size_t i = 0; i < dims().size(); ++i) { for (size_t j = i + 1; j < dims().size(); ++j) { @@ -486,9 +486,10 @@ public: mDims = std::move(newDims); mStrides = std::move(newStrides); } - + mDataFormat = df; } + /** * @brief Get the Impl object * @return constexpr const std::shared_ptr<TensorImpl>&