From 2f0e6ffeee33d68efe8cca6f05326d85cd8a6fc1 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 5 Dec 2024 15:47:26 +0000
Subject: [PATCH] add some tests for Tensor addition operator and subtraction
 operator

---
 unit_tests/data/Test_TensorImpl.cpp | 96 +++++++++++++++++++++++++++++
 1 file changed, 96 insertions(+)

diff --git a/unit_tests/data/Test_TensorImpl.cpp b/unit_tests/data/Test_TensorImpl.cpp
index 5db37c96..fd938f10 100644
--- a/unit_tests/data/Test_TensorImpl.cpp
+++ b/unit_tests/data/Test_TensorImpl.cpp
@@ -193,4 +193,100 @@ TEST_CASE("Test division of Tensors","[TensorImpl][Div]") {
     Tensor T3(T1.dims());
     REQUIRE_THROWS(T0 / T3);
 }
+
+TEST_CASE("Tensor arithmetic operators", "[Tensor][Operator][CPU]") {
+    SECTION("Addition") {
+        const Tensor t = Array1D<std::int32_t, 5>{1,2,3,4,5};
+        const Tensor t2 = Array1D<std::int32_t, 5>{10,20,30,40,50};
+        const Tensor t3 = Tensor(std::int32_t(3));
+
+        SECTION("operator+") {
+            auto a = t.clone();
+            auto b = t2.clone();
+            auto c = t3.clone();
+
+            // simple addition
+            auto r1 = a + b;
+            const Tensor expected_res_simple = Array1D<std::int32_t, 5>{11,22,33,44,55};
+
+            // input tensors are not modified
+            REQUIRE(a == t);
+            REQUIRE(b == t2);
+            // result is right
+            REQUIRE(r1 == expected_res_simple);
+
+            // simple addition of arithmetic value
+            auto r2 = a + 10;
+            const Tensor expected_res_simple_arithmetic = Array1D<std::int32_t, 5>{11,12,13,14,15};
+
+            // input tensors are not modified
+            REQUIRE(a == t);
+            // result is right
+            REQUIRE(r2 == expected_res_simple_arithmetic);
+
+
+            // chained addition a+b+c
+            auto r3 = a + b + c;
+            const Tensor expected_res_chained = Array1D<std::int32_t, 5>{14,25,36,47,58};
+
+            // input Tensors are not modified
+            REQUIRE(a == t);
+            REQUIRE(b == t2);
+            REQUIRE(c == t3);
+            // result is right
+            REQUIRE(r3 == expected_res_chained);
+        }
+        SECTION("operator+=") {
+            auto a = t.clone();
+            auto b = t2.clone();
+
+            a += b;
+            const Tensor expected_res = Array1D<std::int32_t, 5>{11,22,33,44,55};
+
+            // input tensors are not modified
+            REQUIRE(b == t2);
+            // result is right
+            REQUIRE(a == expected_res);
+
+            // simple addition of arithmetic value
+            a = t.clone();
+            a += 10;
+            const Tensor expected_res_arithmetic = Array1D<std::int32_t, 5>{11,12,13,14,15};
+
+            // result is right
+            REQUIRE(a == expected_res_arithmetic);
+        }
+    }
+    SECTION("Substraction") {
+        const Tensor t = Array1D<std::int32_t, 5>{1,2,3,4,5};
+        const Tensor t2 = Tensor(std::int32_t(3));
+
+        SECTION("operator-") {
+            auto a = t.clone();
+            auto b = t2.clone();
+
+            // simple substraction
+            auto r1 = a - b;
+            const Tensor expected_res_simple = Array1D<std::int32_t, 5>{-2,-1,0,1,2};
+
+            // input tensors are not modified
+            REQUIRE(a == t);
+            REQUIRE(b == t2);
+            // result is right
+            REQUIRE(r1 == expected_res_simple);
+        }
+        SECTION("operator-=") {
+            auto a = t.clone();
+            auto b = t2.clone();
+
+            a -= b;
+            const Tensor expected_res = Array1D<std::int32_t, 5>{-2,-1,0,1,2};
+
+            // input tensors are not modified
+            REQUIRE(b == t2);
+            // result is right
+            REQUIRE(a == expected_res);
+        }
+    }
+}
 } // namespace Aidge
-- 
GitLab