From 70a0be1f8b0949975cc3791294d589b610313715 Mon Sep 17 00:00:00 2001 From: Jerome Hue <jerome.hue@cea.fr> Date: Wed, 12 Feb 2025 16:41:55 +0100 Subject: [PATCH] Add unit test for Tensor.repeat() --- unit_tests/data/Test_Tensor.cpp | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp index bfdc1a6b9..28f3a5fde 100644 --- a/unit_tests/data/Test_Tensor.cpp +++ b/unit_tests/data/Test_Tensor.cpp @@ -505,6 +505,35 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") { } } + SECTION("repeat") { + Tensor tensor = Array2D<int, 2, 3>{{{1, 2, 3}, + {4, 5, 6}}}; + const int repeatTimes = 4; + + Tensor repeated; + REQUIRE_NOTHROW(repeated = tensor.repeat(repeatTimes)); + + // The expected shape after repeating is {repeatTimes, 2, 3} + std::vector<DimSize_t> expectedDims = {static_cast<DimSize_t>(repeatTimes), 2, 3}; + CHECK(repeated.dims() == expectedDims); + + // For each repetition along the new dimension, extract the slice and verify + // that it matches the original tensor + for (int i = 0; i < repeatTimes; ++i) { + Tensor slice; + REQUIRE_NOTHROW(slice = repeated.extract({static_cast<std::size_t>(i)})); + CHECK(slice.dims() == tensor.dims()); + + // Compare slice with original tensor elementwise + for (std::size_t idx = 0; idx < tensor.size(); ++idx) { + int expectedVal = tensor.get<int>(idx); + int sliceVal = slice.get<int>(idx); + INFO("Mismatch in repetition " << i << " at flat index " << idx); + CHECK(sliceVal == expectedVal); + } + } + } + // print, toString, SECTION("Pretty printing for debug") { Tensor x{}; -- GitLab