Skip to content
Snippets Groups Projects

[Feat] ForwardDims Conv for NCHW and NHWC

Merged Wissam Boussella requested to merge wboussella/aidge_core:nhwc_for_conv into dev
3 unresolved threads
Files
4
@@ -462,12 +462,39 @@ public:
* data is copy-transposed.
*/
void setDataFormat(const DataFormat df, bool copyTrans = true) {
if (mImpl && copyTrans && (dataFormat() != df) && df != DataFormat::Default && dataFormat() != DataFormat::Default) {
copyTranspose(*this, getDataFormatTranspose(dataFormat(), df));
if (!copyTrans || df == dataFormat()) {
mDataFormat = df;
return;
}
if ((df == DataFormat::Default && dataFormat() == DataFormat::Default) ||
Please register or sign in to reply
(df == DataFormat::NCHW && dataFormat() == DataFormat::NCHW)) {
mDataFormat = df;
return;
}
const auto transpose = getDataFormatTranspose(dataFormat(), df);
if (mImpl) {
copyTranspose(*this, transpose);
} else {
std::vector<DimSize_t> newDims;
for (std::size_t i = 0; i < dims().size(); ++i) {
newDims.push_back(dims()[transpose[i]]);
}
std::vector<std::size_t> newStrides(dims().size(), 1);
for (size_t i = 0; i < dims().size(); ++i) {
for (size_t j = i + 1; j < dims().size(); ++j) {
newStrides[i] *= newDims[j];
}
}
mDims = std::move(newDims);
mStrides = std::move(newStrides);
}
mDataFormat = df;
}
/**
* @brief Get the Impl object
* @return constexpr const std::shared_ptr<TensorImpl>&
Loading