From e5cc0ce2639b85ec6cad3dfa0a4b67444fc8d812 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 13 Feb 2024 17:47:18 +0100
Subject: [PATCH] First working version of LSTM inference
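
Add an LSTM(forward_values) test that runs the LSTM meta-operator on known
values (a single time step on a 3x2 input, all weights set to 0.1, zero
initial hidden/cell states) and checks the resulting hidden state against
the expected values. The existing LSTM(forward) test is updated for the
new LSTM() argument.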

---
 unit_tests/operator/Test_MetaOperator.cpp | 63 ++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 1 deletion(-)

diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index 2b63e1e1..4f74c5ce 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -191,7 +191,7 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
   }
   
     SECTION("LSTM(forward)") {
-        auto myLSTM = LSTM(32, 64, 16, "ltsm");
+        auto myLSTM = LSTM(32, 64, 16, true, "ltsm");
         auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
 
         auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
@@ -221,4 +221,65 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
         microGraphScheduler->saveSchedulingDiagram("lstm_scheduling");
     }
+
+    SECTION("LSTM(forward_values)") {
+        auto myLSTM = LSTM(2, 3, 1, true, "ltsm");
+        auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
+
+        auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
+        microGraph->save("lstm", false, false);
+
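+        // The LSTM exposes 3 data inputs (X and the initial hidden/cell states), 8 weights and 8 biases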
+        REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
+        REQUIRE(myLSTM->nbData() == 3);
+        REQUIRE(myLSTM->nbOutputs() == 2);
+
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}});
+        std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
+        std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
+        std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
+
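+        // Data inputs: #0 is the input sequence, #17 and #18 the zero-initialised hidden and cell states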
+        op->associateInput(0, myInput);
+        op->associateInput(17, myInit);
+        op->associateInput(18, myInit);
+
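+        // Output dimensions can already be computed from the data inputs alone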
+        op->computeOutputDims();
+        REQUIRE(op->outputDimsForwarded());
+        microGraph->save("lstm_values_dims", false, true);
+
+        // Weights X (input weights, one tensor per gate)
+        op->associateInput(1, myInitW);
+        op->associateInput(2, myInitW);
+        op->associateInput(3, myInitW);
+        op->associateInput(4, myInitW);
+        // Weights H (recurrent weights, one tensor per gate)
+        op->associateInput(5, myInitR);
+        op->associateInput(6, myInitR);
+        op->associateInput(7, myInitR);
+        op->associateInput(8, myInitR);
+
+        op->setDataType(DataType::Float32);
+        op->setBackend("cpu");
+
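+        // Expected hidden state after a single time step with all weights at 0.1 and zero initial state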
+        std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
+                Array2D<float, 3, 3>{{{0.0952412, 0.0952412, 0.0952412},
+                                     {0.25606447, 0.25606447, 0.25606447},
+                                     {0.40323776, 0.40323776, 0.40323776}}});
+
+        op->forward();
+        auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
+        microGraphScheduler->saveSchedulingDiagram("lstm_values_scheduling");
+
+        op->getOutput(1)->print();
+        myHiddenState->print();
+
+        REQUIRE(approxEq<float>(*(op->getOutput(1)), *myHiddenState));
+    }
 }
\ No newline at end of file
-- 
GitLab