From 70a0be1f8b0949975cc3791294d589b610313715 Mon Sep 17 00:00:00 2001
From: Jerome Hue <jerome.hue@cea.fr>
Date: Wed, 12 Feb 2025 16:41:55 +0100
Subject: [PATCH] Add unit test for Tensor.repeat()

---
 unit_tests/data/Test_Tensor.cpp | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/unit_tests/data/Test_Tensor.cpp b/unit_tests/data/Test_Tensor.cpp
index bfdc1a6b9..28f3a5fde 100644
--- a/unit_tests/data/Test_Tensor.cpp
+++ b/unit_tests/data/Test_Tensor.cpp
@@ -505,6 +505,35 @@ TEST_CASE("[core/data] Tensor(other)", "[Tensor][extract][zeros][print]") {
         }
     }
 
+    SECTION("repeat") {
+        Tensor tensor = Array2D<int, 2, 3>{{{1, 2, 3},
+                                            {4, 5, 6}}};
+        const int repeatTimes = 4;
+        
+        Tensor repeated;
+        REQUIRE_NOTHROW(repeated = tensor.repeat(repeatTimes));
+        
+        // The expected shape after repeating is {repeatTimes, 2, 3}
+        std::vector<DimSize_t> expectedDims = {static_cast<DimSize_t>(repeatTimes), 2, 3};
+        CHECK(repeated.dims() == expectedDims);
+        
+        // For each repetition along the new dimension, extract the slice and verify
+        // that it matches the original tensor
+        for (int i = 0; i < repeatTimes; ++i) {
+            Tensor slice;
+            REQUIRE_NOTHROW(slice = repeated.extract({static_cast<std::size_t>(i)}));
+            CHECK(slice.dims() == tensor.dims());
+            
+            // Compare slice with original tensor elementwise
+            for (std::size_t idx = 0; idx < tensor.size(); ++idx) {
+                int expectedVal = tensor.get<int>(idx);
+                int sliceVal = slice.get<int>(idx);
+                INFO("Mismatch in repetition " << i << " at flat index " << idx);
+                CHECK(sliceVal == expectedVal);
+            }
+        }
+    }
+
     // print, toString,
     SECTION("Pretty printing for debug") {
         Tensor x{};
-- 
GitLab