Skip to content
Snippets Groups Projects
Commit 6d0b593b authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix: do not transpose data in Tensor if current or new format is Default

parent 8947fddd
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!355[Fix] Tensor::setDataFormat handling of DataFormat::Default
...@@ -431,7 +431,7 @@ public: ...@@ -431,7 +431,7 @@ public:
* @brief Get the data format enum. * @brief Get the data format enum.
* @return constexpr DataFormat * @return constexpr DataFormat
*/ */
constexpr DataFormat dataFormat() const noexcept { return mDataFormat; } const DataFormat& dataFormat() const noexcept { return mDataFormat; }
/** /**
* @brief Set the DataType of the Tensor and converts data * @brief Set the DataType of the Tensor and converts data
...@@ -462,13 +462,13 @@ public: ...@@ -462,13 +462,13 @@ public:
* data is copy-transposed. * data is copy-transposed.
*/ */
void setDataFormat(const DataFormat df, bool copyTrans = true) { void setDataFormat(const DataFormat df, bool copyTrans = true) {
if (!copyTrans || df == dataFormat()) { if (!copyTrans || df == dataFormat() || df == DataFormat::Default || dataFormat() == DataFormat::Default) {
mDataFormat = df; mDataFormat = df;
return; return;
} }
const auto transpose = getDataFormatTranspose(dataFormat(), df); const auto transpose = getPermutationMapping(dataFormat(), df);
if (mImpl) { if (mImpl) {
copyTranspose(*this, transpose); copyTranspose(*this, transpose);
} else { } else {
...@@ -476,7 +476,7 @@ public: ...@@ -476,7 +476,7 @@ public:
for (std::size_t i = 0; i < dims().size(); ++i) { for (std::size_t i = 0; i < dims().size(); ++i) {
newDims.push_back(dims()[transpose[i]]); newDims.push_back(dims()[transpose[i]]);
} }
std::vector<std::size_t> newStrides(dims().size(), 1); std::vector<std::size_t> newStrides(dims().size(), 1);
for (size_t i = 0; i < dims().size(); ++i) { for (size_t i = 0; i < dims().size(); ++i) {
for (size_t j = i + 1; j < dims().size(); ++j) { for (size_t j = i + 1; j < dims().size(); ++j) {
...@@ -486,9 +486,10 @@ public: ...@@ -486,9 +486,10 @@ public:
mDims = std::move(newDims); mDims = std::move(newDims);
mStrides = std::move(newStrides); mStrides = std::move(newStrides);
} }
mDataFormat = df; mDataFormat = df;
} }
/** /**
* @brief Get the Impl object * @brief Get the Impl object
* @return constexpr const std::shared_ptr<TensorImpl>& * @return constexpr const std::shared_ptr<TensorImpl>&
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment