Skip to content
Snippets Groups Projects
Commit 7f6768d0 authored by Wissam Boussella's avatar Wissam Boussella
Browse files

Refactor setDataFormat method to improve data type handling and transformation logic

parent 167021e1
No related branches found
No related tags found
No related merge requests found
......@@ -452,19 +452,44 @@ public:
}
/**
* @brief Set the DataFormat of the Tensor and transpose data, only
* if the Tensor has already been initialized and copyTrans is true.
* In this case, a transposition occurs only if both previous format and
* new format are different from DataFormat::Default.
* @param df New DataFormat
* @param copyTrans If true (default), when both previous format and new
* format are different from DataFormat::Default, previous
* data is copy-transposed.
* @brief Set the DataType of the Tensor and converts data
* if the Tensor has already been initialized and copyCast is true.
* @param dt DataType
* @param copyCast If true (default), previous data is copy-casted. Otherwise
* previous data is lost.
*/
void setDataFormat(const DataFormat df, bool copyTrans = true) {
if (mImpl && copyTrans && (dataFormat() != df) && df != DataFormat::Default && dataFormat() != DataFormat::Default) {
copyTranspose(*this, getDataFormatTranspose(dataFormat(), df));
if (!copyTrans || df == dataFormat()) {
mDataFormat = df;
return;
}
// Skip transformation if both formats are Default or NNCH
if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) || df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW) {
mDataFormat = df;
return;
}
const auto transpose = getDataFormatTranspose(dataFormat(), df);
if (mImpl) {
copyTranspose(*this, transpose);
} else {
std::vector<DimSize_t> newDims;
newDims.reserve(nbDims());
for (std::size_t i = 0; i < nbDims(); ++i) {
newDims.push_back(dims()[transpose[i]]);
}
mDims = std::move(newDims);
std::vector<std::size_t> newStrides(nbDims(), 1);
for (size_t i = 0; i < nbDims(); ++i) {
for (size_t j = i + 1; j < nbDims(); ++j) {
newStrides[i] *= newDims[j];
}
}
mStrides = std::move(newStrides);
}
mDataFormat = df;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment