Skip to content
Snippets Groups Projects
Commit 5f4f8814 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

FIxed dims in unit tests

parent 41b356b3
No related branches found
No related tags found
2 merge requests!118v0.4.0,!109Add LRN operator
Pipeline #60177 passed
......@@ -124,7 +124,9 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
dims_in[1]; // averaging per channel : 1 addition per element in
// the channel + 1 division this for every batch
// create out nb_elems
std::vector<std::size_t> dims_out{dims_in[0], dims_in[1]};
std::vector<std::size_t> dims_out(dims_in.size(), 1);
dims_out[0] = dims_in[0];
dims_out[1] = dims_in[1];
const std::size_t out_nb_elems =
std::accumulate(dims_out.cbegin(), dims_out.cend(), std::size_t(1),
std::multiplies<std::size_t>());
......@@ -151,7 +153,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
T0->getImpl()->setRawPtr(array0, in_nb_elems);
// results
Tres->resize(dims_in);
Tres->resize(dims_out);
Tres->getImpl()->setRawPtr(result, out_nb_elems);
op->forwardDims();
......@@ -192,7 +194,9 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
// the channel + 1 division this for every batch
// create out nb_elems
std::vector<std::size_t> dims_out{dims_in[0], dims_in[1]};
std::vector<std::size_t> dims_out(dims_in.size(), 1);
dims_out[0] = dims_in[0];
dims_out[1] = dims_in[1];
const std::size_t out_nb_elems =
std::accumulate(dims_out.cbegin(), dims_out.cend(),
std::size_t(1), std::multiplies<std::size_t>());
......@@ -222,7 +226,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
T0->getImpl()->setRawPtr(array0, in_nb_elems);
// results
Tres->resize(dims_in);
Tres->resize(dims_out);
Tres->getImpl()->setRawPtr(result, out_nb_elems);
op->forwardDims();
......@@ -253,7 +257,9 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
SECTION("2D_img") {
const std::vector<DimSize_t> in_dims{batch_size, channels, height,
width};
const std::vector<DimSize_t> out_dims{batch_size, channels};
std::vector<std::size_t> out_dims(in_dims.size(), 1);
out_dims[0] = in_dims[0];
out_dims[1] = in_dims[1];
DimSize_t in_nb_elems = batch_size * channels * height * width;
DimSize_t out_nb_elems = batch_size * channels;
number_of_operation +=
......@@ -348,7 +354,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
T0->getImpl()->setRawPtr(input, in_nb_elems);
// results
Tres->resize(in_dims);
Tres->resize(out_dims);
Tres->getImpl()->setRawPtr(result, out_nb_elems);
op->forwardDims();
start = std::chrono::system_clock::now();
......@@ -368,7 +374,9 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
SECTION("3D_img") {
const std::vector<DimSize_t> in_dims{batch_size, channels, height,
width, depth};
const std::vector<DimSize_t> out_dims{batch_size, channels};
std::vector<std::size_t> out_dims(in_dims.size(), 1);
out_dims[0] = in_dims[0];
out_dims[1] = in_dims[1];
DimSize_t in_nb_elems =
batch_size * channels * height * width * depth;
number_of_operation +=
......@@ -535,7 +543,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
T0->getImpl()->setRawPtr(input, in_nb_elems);
// results
Tres->resize(in_dims);
Tres->resize(out_dims);
Tres->getImpl()->setRawPtr(result, out_nb_elems);
op->forwardDims();
start = std::chrono::system_clock::now();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment