From 668c906d8c772ed6c5765cd0c50b73f2849d96e8 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 11 Apr 2024 17:19:10 +0200
Subject: [PATCH] Fixed required forwardDims() because input dims change

---
 unit_tests/operator/Test_DivImpl.cpp                  | 3 +++
 unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp | 4 ++++
 unit_tests/operator/Test_MatMulImpl.cpp               | 4 ++++
 unit_tests/operator/Test_MulImpl.cpp                  | 3 +++
 unit_tests/operator/Test_PowImpl.cpp                  | 3 +++
 unit_tests/operator/Test_SubImpl.cpp                  | 3 +++
 6 files changed, 20 insertions(+)

diff --git a/unit_tests/operator/Test_DivImpl.cpp b/unit_tests/operator/Test_DivImpl.cpp
index 552882ac..5d7dfdf1 100644
--- a/unit_tests/operator/Test_DivImpl.cpp
+++ b/unit_tests/operator/Test_DivImpl.cpp
@@ -103,6 +103,7 @@ TEST_CASE("[cpu/operator] Div", "[Div][CPU]") {
                 Tres->resize(dims);
                 Tres -> getImpl() -> setRawPtr(result, nb_elements);
 
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myDiv->forward();
                 end = std::chrono::system_clock::now();
@@ -195,6 +196,7 @@ TEST_CASE("[cpu/operator] Div", "[Div][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myDiv->forward();
                 end = std::chrono::system_clock::now();
@@ -289,6 +291,7 @@ TEST_CASE("[cpu/operator] Div", "[Div][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myDiv->forward();
                 end = std::chrono::system_clock::now();
diff --git a/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp b/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp
index 9c357dc9..43903100 100644
--- a/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp
+++ b/unit_tests/operator/Test_GlobalAveragePoolingImpl.cpp
@@ -154,6 +154,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
         Tres->resize(dims_out);
         Tres->getImpl()->setRawPtr(result, out_nb_elems);
 
+        op->forwardDims();
         start = std::chrono::system_clock::now();
         REQUIRE_NOTHROW(globAvgPool->forward());
         end = std::chrono::system_clock::now();
@@ -224,6 +225,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
           Tres->resize(dims_out);
           Tres->getImpl()->setRawPtr(result, out_nb_elems);
 
+          op->forwardDims();
           start = std::chrono::system_clock::now();
           REQUIRE_NOTHROW(globAvgPool->forward());
           end = std::chrono::system_clock::now();
@@ -348,6 +350,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
           // results
           Tres->resize(out_dims);
           Tres->getImpl()->setRawPtr(result, out_nb_elems);
+          op->forwardDims();
           start = std::chrono::system_clock::now();
           REQUIRE_NOTHROW(globAvgPool->forward());
           end = std::chrono::system_clock::now();
@@ -534,6 +537,7 @@ TEST_CASE("[cpu/operator] GlobalAveragePooling",
           // results
           Tres->resize(out_dims);
           Tres->getImpl()->setRawPtr(result, out_nb_elems);
+          op->forwardDims();
           start = std::chrono::system_clock::now();
           REQUIRE_NOTHROW(globAvgPool->forward());
           end = std::chrono::system_clock::now();
diff --git a/unit_tests/operator/Test_MatMulImpl.cpp b/unit_tests/operator/Test_MatMulImpl.cpp
index 414b38f0..8a1e589f 100644
--- a/unit_tests/operator/Test_MatMulImpl.cpp
+++ b/unit_tests/operator/Test_MatMulImpl.cpp
@@ -94,6 +94,7 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
             op->associateInput(1, T2);
             op->setDataType(DataType::Float32);
             op->setBackend("cpu");
+            op->forwardDims();
             start = std::chrono::system_clock::now();
             myMatMul->forward();
             end = std::chrono::system_clock::now();
@@ -157,6 +158,7 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
             op->associateInput(1, T2);
             op->setDataType(DataType::Float32);
             op->setBackend("cpu");
+            op->forwardDims();
             start = std::chrono::system_clock::now();
             myMatMul->forward();
             end = std::chrono::system_clock::now();
@@ -223,6 +225,7 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
             op->associateInput(1, T2);
             op->setDataType(DataType::Float32);
             op->setBackend("cpu");
+            op->forwardDims();
             start = std::chrono::system_clock::now();
             myMatMul->forward();
             end = std::chrono::system_clock::now();
@@ -255,6 +258,7 @@ TEST_CASE("[cpu/operator] MatMul(forward)", "[MatMul][CPU]") {
 
         op->setDataType(DataType::Float32);
         op->setBackend("cpu");
+        op->forwardDims();
         myMatMul->forward();
 
     }
diff --git a/unit_tests/operator/Test_MulImpl.cpp b/unit_tests/operator/Test_MulImpl.cpp
index d1bd7d0d..9d592d31 100644
--- a/unit_tests/operator/Test_MulImpl.cpp
+++ b/unit_tests/operator/Test_MulImpl.cpp
@@ -103,6 +103,7 @@ TEST_CASE("[cpu/operator] Mul", "[Mul][CPU]") {
                 Tres->resize(dims);
                 Tres -> getImpl() -> setRawPtr(result, nb_elements);
 
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myMul->forward();
                 end = std::chrono::system_clock::now();
@@ -195,6 +196,7 @@ TEST_CASE("[cpu/operator] Mul", "[Mul][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myMul->forward();
                 end = std::chrono::system_clock::now();
@@ -289,6 +291,7 @@ TEST_CASE("[cpu/operator] Mul", "[Mul][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myMul->forward();
                 end = std::chrono::system_clock::now();
diff --git a/unit_tests/operator/Test_PowImpl.cpp b/unit_tests/operator/Test_PowImpl.cpp
index 7f1ee39f..3b85defb 100644
--- a/unit_tests/operator/Test_PowImpl.cpp
+++ b/unit_tests/operator/Test_PowImpl.cpp
@@ -104,6 +104,7 @@ TEST_CASE("[cpu/operator] Pow", "[Pow][CPU]") {
                 Tres->resize(dims);
                 Tres -> getImpl() -> setRawPtr(result, nb_elements);
 
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myPow->forward();
                 end = std::chrono::system_clock::now();
@@ -196,6 +197,7 @@ TEST_CASE("[cpu/operator] Pow", "[Pow][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myPow->forward();
                 end = std::chrono::system_clock::now();
@@ -290,6 +292,7 @@ TEST_CASE("[cpu/operator] Pow", "[Pow][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 myPow->forward();
                 end = std::chrono::system_clock::now();
diff --git a/unit_tests/operator/Test_SubImpl.cpp b/unit_tests/operator/Test_SubImpl.cpp
index ef818e8d..44666ae6 100644
--- a/unit_tests/operator/Test_SubImpl.cpp
+++ b/unit_tests/operator/Test_SubImpl.cpp
@@ -103,6 +103,7 @@ TEST_CASE("[cpu/operator] Sub", "[Sub][CPU]") {
                 Tres->resize(dims);
                 Tres -> getImpl() -> setRawPtr(result, nb_elements);
 
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 mySub->forward();
                 end = std::chrono::system_clock::now();
@@ -195,6 +196,7 @@ TEST_CASE("[cpu/operator] Sub", "[Sub][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 mySub->forward();
                 end = std::chrono::system_clock::now();
@@ -289,6 +291,7 @@ TEST_CASE("[cpu/operator] Sub", "[Sub][CPU]") {
                 Tres -> getImpl() -> setRawPtr(result, dimsOut[0]*dimsOut[1]*dimsOut[2]*dimsOut[3]);
 
                 // compute result
+                op->forwardDims();
                 start = std::chrono::system_clock::now();
                 mySub->forward();
                 end = std::chrono::system_clock::now();
-- 
GitLab