Skip to content
Snippets Groups Projects

[Fix] vit_operators

Merged Maxence Naud requested to merge vit_operators into dev
1 file
+ 2
2
Compare changes
  • Side-by-side
  • Inline
@@ -79,12 +79,12 @@ class ReduceMean_Op : public OperatorTensor,
@@ -79,12 +79,12 @@ class ReduceMean_Op : public OperatorTensor,
if (val < 0)
if (val < 0)
val+=static_cast<std::int32_t>(getInput(0)->nbDims());
val+=static_cast<std::int32_t>(getInput(0)->nbDims());
});
});
std::sort(axes.cbegin(), axes.cend());
std::sort(axes.begin(), axes.end());
// build output dimensions
// build output dimensions
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> outDims = getInput(0)->dims();
if (this->template getAttr<ReduceMeanAttr::KeepDims>()) {
if (this->template getAttr<ReduceMeanAttr::KeepDims>()) {
std::for_each(axes.begin(), axes.end(), [&] (const std::int32_t& val) { outDims[val] = 1; });
std::for_each(axes.begin(), axes.end(), [&outDims] (const std::int32_t& val) { outDims[val] = 1; });
}
}
else {
else {
for (auto it = axes.crbegin(); it != axes.crend(); ++it)
for (auto it = axes.crbegin(); it != axes.crend(); ++it)
Loading