diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 5e4a817bb9ac22907c51841387d71c3809cd7e13..5c184e9613f5be5f9fd2bd49aa04bfdd83105c72 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -452,19 +452,23 @@ public: } /** - * @brief Set the DataType of the Tensor and converts data - * if the Tensor has already been initialized and copyCast is true. - * @param dt DataType - * @param copyCast If true (default), previous data is copy-casted. Otherwise - * previous data is lost. + * @brief Set the DataFormat of the Tensor and transpose data, only + * if the Tensor has already been initialized and copyTrans is true. + * In this case, a transposition occurs only if both previous format and + * new format are different from DataFormat::Default. + * @param df New DataFormat + * @param copyTrans If true (default), when both previous format and new + * format are different from DataFormat::Default, previous + * data is copy-transposed. */ void setDataFormat(const DataFormat df, bool copyTrans = true) { if (!copyTrans || df == dataFormat()) { mDataFormat = df; return; } - // Skip transformation if both formats are Default or NNCH - if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) || df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW) { + + if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) || + (df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW)) { mDataFormat = df; return; } @@ -475,24 +479,22 @@ public: copyTranspose(*this, transpose); } else { std::vector<DimSize_t> newDims; - newDims.reserve(nbDims()); - for (std::size_t i = 0; i < nbDims(); ++i) { + for (std::size_t i = 0; i < dims().size(); ++i) { newDims.push_back(dims()[transpose[i]]); } - mDims = std::move(newDims); - - std::vector<std::size_t> newStrides(nbDims(), 1); - for (size_t i = 0; i < nbDims(); ++i) { - for (size_t j = i + 1; j < nbDims(); ++j) { + + std::vector<std::size_t> newStrides(dims().size(), 1); + for (size_t i = 0; i < dims().size(); ++i) { + for (size_t j = i + 1; j < dims().size(); ++j) { newStrides[i] *= newDims[j]; } } + mDims = std::move(newDims); mStrides = std::move(newStrides); } mDataFormat = df; } - /** * @brief Get the Impl object * @return constexpr const std::shared_ptr<TensorImpl>&