Skip to content
Snippets Groups Projects
Commit 70a0be1f authored by Jerome Hue's avatar Jerome Hue Committed by Maxence Naud
Browse files

Add unit test for Tensor.repeat()

parent 764504df
No related branches found
No related tags found
1 merge request!351feat: add rate spikegen for snns
...@@ -505,6 +505,35 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") { ...@@ -505,6 +505,35 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") {
} }
} }
SECTION("repeat") {
Tensor tensor = Array2D<int, 2, 3>{{{1, 2, 3},
{4, 5, 6}}};
const int repeatTimes = 4;
Tensor repeated;
REQUIRE_NOTHROW(repeated = tensor.repeat(repeatTimes));
// The expected shape after repeating is {repeatTimes, 2, 3}
std::vector<DimSize_t> expectedDims = {static_cast<DimSize_t>(repeatTimes), 2, 3};
CHECK(repeated.dims() == expectedDims);
// For each repetition along the new dimension, extract the slice and verify
// that it matches the original tensor
for (int i = 0; i < repeatTimes; ++i) {
Tensor slice;
REQUIRE_NOTHROW(slice = repeated.extract({static_cast<std::size_t>(i)}));
CHECK(slice.dims() == tensor.dims());
// Compare slice with original tensor elementwise
for (std::size_t idx = 0; idx < tensor.size(); ++idx) {
int expectedVal = tensor.get<int>(idx);
int sliceVal = slice.get<int>(idx);
INFO("Mismatch in repetition " << i << " at flat index " << idx);
CHECK(sliceVal == expectedVal);
}
}
}
// print, toString, // print, toString,
SECTION("Pretty printing for debug") { SECTION("Pretty printing for debug") {
Tensor x{}; Tensor x{};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment