diff --git a/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp b/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp index d5f2065b624de431b43edef9a83bf079905129dd..51b4366d2a0332562f8aac78a303c9304bde2529 100644 --- a/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp +++ b/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp @@ -151,7 +151,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling", T0->getImpl()->setRawPtr(array0, in_nb_elems); // results - Tres->resize(dims_out); + Tres->resize(dims_in); Tres->getImpl()->setRawPtr(result, out_nb_elems); op->forwardDims(); @@ -222,7 +222,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling", T0->getImpl()->setRawPtr(array0, in_nb_elems); // results - Tres->resize(dims_out); + Tres->resize(dims_in); Tres->getImpl()->setRawPtr(result, out_nb_elems); op->forwardDims(); @@ -348,7 +348,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling", T0->getImpl()->setRawPtr(input, in_nb_elems); // results - Tres->resize(out_dims); + Tres->resize(in_dims); Tres->getImpl()->setRawPtr(result, out_nb_elems); op->forwardDims(); start = std::chrono::system_clock::now(); @@ -535,7 +535,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling", T0->getImpl()->setRawPtr(input, in_nb_elems); // results - Tres->resize(out_dims); + Tres->resize(in_dims); Tres->getImpl()->setRawPtr(result, out_nb_elems); op->forwardDims(); start = std::chrono::system_clock::now();