From b05788821a509cc7450b3804591bbe8bf7962eac Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 6 Nov 2024 10:08:34 +0000
Subject: [PATCH] fix: update 'GlobalAvgPooling::forwardDims()' unit tests

---
 unit_tests/operator/Test_GlobalAveragePooling_Op.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp b/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp
index 15c714b63..29e6f0a56 100644
--- a/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp
+++ b/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp
@@ -70,12 +70,14 @@ TEST_CASE("[core/operator] GlobalAveragePooling_Op(forwardDims)",
           for (uint16_t i = 0; i < nb_dims; ++i) {
             dims[i] = dimsDist(gen) + 1;
           }
-          std::vector<DimSize_t> dims_out{dims[0], dims[1]};
+          std::vector<DimSize_t> dims_out(nb_dims, 1);
+          dims_out[0] = dims[0];
+          dims_out[1] = dims[1];
           input_T->resize(dims);
           op->setInput(0, input_T);
           REQUIRE_NOTHROW(op->forwardDims());
           REQUIRE(op->getOutput(0)->dims() == dims_out);
-          REQUIRE((op->getOutput(0)->dims().size()) == static_cast<size_t>(2));
+          REQUIRE((op->getOutput(0)->dims().size()) == nb_dims);
         }
       }
     }
-- 
GitLab