Skip to content
Snippets Groups Projects
Commit d2689b32 authored by Wissam Boussella's avatar Wissam Boussella Committed by Olivier BICHLER
Browse files

Refactor setDataFormat method to improve data type handling and transformation logic

parent a9acb2f8
No related branches found
No related tags found
1 merge request!314[Feat] ForwardDims Conv for NCHW and NHWC
......@@ -452,19 +452,44 @@ public:
}
/**
* @brief Set the DataFormat of the Tensor and transpose data, only
* if the Tensor has already been initialized and copyTrans is true.
* In this case, a transposition occurs only if both previous format and
* new format are different from DataFormat::Default.
* @param df New DataFormat
* @param copyTrans If true (default), when both previous format and new
* format are different from DataFormat::Default, previous
* data is copy-transposed.
* @brief Set the DataType of the Tensor and converts data
* if the Tensor has already been initialized and copyCast is true.
* @param dt DataType
* @param copyCast If true (default), previous data is copy-casted. Otherwise
* previous data is lost.
*/
void setDataFormat(const DataFormat df, bool copyTrans = true) {
if (mImpl && copyTrans && (dataFormat() != df) && df != DataFormat::Default && dataFormat() != DataFormat::Default) {
copyTranspose(*this, getDataFormatTranspose(dataFormat(), df));
if (!copyTrans || df == dataFormat()) {
mDataFormat = df;
return;
}
// Skip transformation if both formats are Default or NNCH
if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) || df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW) {
mDataFormat = df;
return;
}
const auto transpose = getDataFormatTranspose(dataFormat(), df);
if (mImpl) {
copyTranspose(*this, transpose);
} else {
std::vector<DimSize_t> newDims;
newDims.reserve(nbDims());
for (std::size_t i = 0; i < nbDims(); ++i) {
newDims.push_back(dims()[transpose[i]]);
}
mDims = std::move(newDims);
std::vector<std::size_t> newStrides(nbDims(), 1);
for (size_t i = 0; i < nbDims(); ++i) {
for (size_t j = i + 1; j < nbDims(); ++j) {
newStrides[i] *= newDims[j];
}
}
mStrides = std::move(newStrides);
}
mDataFormat = df;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment