diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp index bfdc1a6b9c058b348942e9c29a77ac4d6db5086f..28f3a5fdec1836d5d5f5c2a42375e49bc4f1dcc5 100644 --- a/unit_tests/data/Test_Tensor.cpp +++ b/unit_tests/data/Test_Tensor.cpp @@ -505,6 +505,35 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") { } } + SECTION("repeat") { + Tensor tensor = Array2D<int, 2, 3>{{{1, 2, 3}, + {4, 5, 6}}}; + const int repeatTimes = 4; + + Tensor repeated; + REQUIRE_NOTHROW(repeated = tensor.repeat(repeatTimes)); + + // The expected shape after repeating is {repeatTimes, 2, 3} + std::vector<DimSize_t> expectedDims = {static_cast<DimSize_t>(repeatTimes), 2, 3}; + CHECK(repeated.dims() == expectedDims); + + // For each repetition along the new dimension, extract the slice and verify + // that it matches the original tensor + for (int i = 0; i < repeatTimes; ++i) { + Tensor slice; + REQUIRE_NOTHROW(slice = repeated.extract({static_cast<std::size_t>(i)})); + CHECK(slice.dims() == tensor.dims()); + + // Compare slice with original tensor elementwise + for (std::size_t idx = 0; idx < tensor.size(); ++idx) { + int expectedVal = tensor.get<int>(idx); + int sliceVal = slice.get<int>(idx); + INFO("Mismatch in repetition " << i << " at flat index " << idx); + CHECK(sliceVal == expectedVal); + } + } + } + // print, toString, SECTION("Pretty printing for debug") { Tensor x{};