Skip to content
Snippets Groups Projects
Commit 6d0b593b authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix: do not transpose data in Tensor if current or new format is Default

parent 8947fddd
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!355[Fix] Tensor::setDataFormat handling of DataFormat::Default
......@@ -431,7 +431,7 @@ public:
* @brief Get the data format enum.
* @return constexpr DataFormat
*/
constexpr DataFormat dataFormat() const noexcept { return mDataFormat; }
const DataFormat& dataFormat() const noexcept { return mDataFormat; }
/**
* @brief Set the DataType of the Tensor and converts data
......@@ -462,13 +462,13 @@ public:
* data is copy-transposed.
*/
void setDataFormat(const DataFormat df, bool copyTrans = true) {
if (!copyTrans || df == dataFormat()) {
if (!copyTrans || df == dataFormat() || df == DataFormat::Default || dataFormat() == DataFormat::Default) {
mDataFormat = df;
return;
}
const auto transpose = getDataFormatTranspose(dataFormat(), df);
const auto transpose = getPermutationMapping(dataFormat(), df);
if (mImpl) {
copyTranspose(*this, transpose);
} else {
......@@ -476,7 +476,7 @@ public:
for (std::size_t i = 0; i < dims().size(); ++i) {
newDims.push_back(dims()[transpose[i]]);
}
std::vector<std::size_t> newStrides(dims().size(), 1);
for (size_t i = 0; i < dims().size(); ++i) {
for (size_t j = i + 1; j < dims().size(); ++j) {
......@@ -486,9 +486,10 @@ public:
mDims = std::move(newDims);
mStrides = std::move(newStrides);
}
mDataFormat = df;
}
/**
* @brief Get the Impl object
* @return constexpr const std::shared_ptr<TensorImpl>&
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment