From c1ecff210a0c3ad8774475356ef64abae5979cd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20KUBLER?= <gregoire.kubler@proton.me> Date: Thu, 5 Sep 2024 15:28:51 +0200 Subject: [PATCH] fix : output dims of squeeze were not properly created --- unit_tests/operator/Test_Squeeze_Op.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/unit_tests/operator/Test_Squeeze_Op.cpp b/unit_tests/operator/Test_Squeeze_Op.cpp index 61a06aa50..471a1dcd1 100644 --- a/unit_tests/operator/Test_Squeeze_Op.cpp +++ b/unit_tests/operator/Test_Squeeze_Op.cpp @@ -388,15 +388,14 @@ TEST_CASE("[core/operator] Squeeze(forward)", "[Squeeze][forward]") { int i = 0; std::vector<DimSize_t> dims_out; dims_out.reserve(dims_in.size()); - std::copy_if(dims_in.begin(), dims_in.end(), std::back_inserter(dims_out), - [&dims_to_squeeze, &i](DimSize_t dim) { - bool ok = dim != 1 || - !std::binary_search(dims_to_squeeze.begin(), - dims_to_squeeze.end(), i); - i++; // incrementing counter since C++ has not enumerate - // fctn (until C++23) - return ok; - }); + for (DimIdx_t i = 0; i < dims_in.size(); ++i) { + if (dims_in[i] == 1 && + std::find(dims_to_squeeze.begin(), dims_to_squeeze.end(), i) != + dims_to_squeeze.end()) { + continue; + } + dims_out.push_back(dims_in[i]); + } CHECK(op->forwardDims()); CHECK(op->getOutput(0)->dims() == dims_out); -- GitLab