diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 7aa2ed52b95e11598a2975558212b00a85dac598..5e4a817bb9ac22907c51841387d71c3809cd7e13 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -452,19 +452,44 @@ public: } /** - * @brief Set the DataFormat of the Tensor and transpose data, only - * if the Tensor has already been initialized and copyTrans is true. - * In this case, a transposition occurs only if both previous format and - * new format are different from DataFormat::Default. - * @param df New DataFormat - * @param copyTrans If true (default), when both previous format and new - * format are different from DataFormat::Default, previous - * data is copy-transposed. + * @brief Set the DataType of the Tensor and converts data + * if the Tensor has already been initialized and copyCast is true. + * @param dt DataType + * @param copyCast If true (default), previous data is copy-casted. Otherwise + * previous data is lost. */ void setDataFormat(const DataFormat df, bool copyTrans = true) { - if (mImpl && copyTrans && (dataFormat() != df) && df != DataFormat::Default && dataFormat() != DataFormat::Default) { - copyTranspose(*this, getDataFormatTranspose(dataFormat(), df)); + if (!copyTrans || df == dataFormat()) { + mDataFormat = df; + return; + } + // Skip transformation if both formats are Default or NNCH + if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) || df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW) { + mDataFormat = df; + return; + } + + const auto transpose = getDataFormatTranspose(dataFormat(), df); + + if (mImpl) { + copyTranspose(*this, transpose); + } else { + std::vector<DimSize_t> newDims; + newDims.reserve(nbDims()); + for (std::size_t i = 0; i < nbDims(); ++i) { + newDims.push_back(dims()[transpose[i]]); + } + mDims = std::move(newDims); + + std::vector<std::size_t> newStrides(nbDims(), 1); + for (size_t i = 0; i < nbDims(); ++i) { + for (size_t j = i + 1; j < nbDims(); ++j) { + newStrides[i] *= newDims[j]; + } + } + mStrides = std::move(newStrides); } + mDataFormat = df; }