Skip to content
Snippets Groups Projects
Commit a9fbc969 authored by Wissam Boussella's avatar Wissam Boussella Committed by Olivier BICHLER
Browse files

setDataFormat fix, working without impl

parent d2689b32
No related branches found
No related tags found
1 merge request!314[Feat] ForwardDims Conv for NCHW and NHWC
...@@ -452,19 +452,23 @@ public: ...@@ -452,19 +452,23 @@ public:
} }
/** /**
* @brief Set the DataType of the Tensor and converts data * @brief Set the DataFormat of the Tensor and transpose data, only
* if the Tensor has already been initialized and copyCast is true. * if the Tensor has already been initialized and copyTrans is true.
* @param dt DataType * In this case, a transposition occurs only if both previous format and
* @param copyCast If true (default), previous data is copy-casted. Otherwise * new format are different from DataFormat::Default.
* previous data is lost. * @param df New DataFormat
* @param copyTrans If true (default), when both previous format and new
* format are different from DataFormat::Default, previous
* data is copy-transposed.
*/ */
void setDataFormat(const DataFormat df, bool copyTrans = true) { void setDataFormat(const DataFormat df, bool copyTrans = true) {
if (!copyTrans || df == dataFormat()) { if (!copyTrans || df == dataFormat()) {
mDataFormat = df; mDataFormat = df;
return; return;
} }
// Skip transformation if both formats are Default or NNCH
if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) || df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW) { if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) ||
(df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW)) {
mDataFormat = df; mDataFormat = df;
return; return;
} }
...@@ -475,24 +479,22 @@ public: ...@@ -475,24 +479,22 @@ public:
copyTranspose(*this, transpose); copyTranspose(*this, transpose);
} else { } else {
std::vector<DimSize_t> newDims; std::vector<DimSize_t> newDims;
newDims.reserve(nbDims()); for (std::size_t i = 0; i < dims().size(); ++i) {
for (std::size_t i = 0; i < nbDims(); ++i) {
newDims.push_back(dims()[transpose[i]]); newDims.push_back(dims()[transpose[i]]);
} }
mDims = std::move(newDims);
std::vector<std::size_t> newStrides(dims().size(), 1);
std::vector<std::size_t> newStrides(nbDims(), 1); for (size_t i = 0; i < dims().size(); ++i) {
for (size_t i = 0; i < nbDims(); ++i) { for (size_t j = i + 1; j < dims().size(); ++j) {
for (size_t j = i + 1; j < nbDims(); ++j) {
newStrides[i] *= newDims[j]; newStrides[i] *= newDims[j];
} }
} }
mDims = std::move(newDims);
mStrides = std::move(newStrides); mStrides = std::move(newStrides);
} }
mDataFormat = df; mDataFormat = df;
} }
/** /**
* @brief Get the Impl object * @brief Get the Impl object
* @return constexpr const std::shared_ptr<TensorImpl>& * @return constexpr const std::shared_ptr<TensorImpl>&
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment