diff --git a/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp b/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp index 15c714b63c2b86e156b43cdaec390ddf60eb7353..29e6f0a56d34935ab5c061897a27b902e03db790 100644 --- a/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp +++ b/unit_tests/operator/Test_GlobalAveragePooling_Op.cpp @@ -70,12 +70,14 @@ TEST_CASE("[core/operator] GlobalAveragePooling_Op(forwardDims)", for (uint16_t i = 0; i < nb_dims; ++i) { dims[i] = dimsDist(gen) + 1; } - std::vector<DimSize_t> dims_out{dims[0], dims[1]}; + std::vector<DimSize_t> dims_out(nb_dims, 1); + dims_out[0] = dims[0]; + dims_out[1] = dims[1]; input_T->resize(dims); op->setInput(0, input_T); REQUIRE_NOTHROW(op->forwardDims()); REQUIRE(op->getOutput(0)->dims() == dims_out); - REQUIRE((op->getOutput(0)->dims().size()) == static_cast<size_t>(2)); + REQUIRE((op->getOutput(0)->dims().size()) == nb_dims); } } }