From 1cdefe5502a003c958a865eeac108aa1096d1051 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 13 Feb 2024 17:21:51 +0100
Subject: [PATCH] Fixed Producers for LSTM

---
 include/aidge/graph/GraphView.hpp             |  4 +-
 include/aidge/operator/MetaOperatorDefs.hpp   | 45 ++++++++++++++-----
 .../operator/pybind_MetaOperatorDefs.cpp      |  4 +-
 unit_tests/operator/Test_MetaOperator.cpp     | 10 ++---
 4 files changed, 45 insertions(+), 18 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 77a759d9b..d3b022463 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -140,8 +140,8 @@ public:
     void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
     void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
 
-    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; };
-    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; };
+    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const { return mInputNodes; };
+    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const { return mOutputNodes; };
 
     /**
      * @brief List outside data input connections of the GraphView.
diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp
index 2c7044f55..d51844e3b 100644
--- a/include/aidge/operator/MetaOperatorDefs.hpp
+++ b/include/aidge/operator/MetaOperatorDefs.hpp
@@ -147,6 +147,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(
 inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
                                   DimSize_t hidden_channels,
                                   DimSize_t seq_length,
+                                  bool noBias = false,
                                   const std::string& name = "")
 {
     // Construct micro-graph
@@ -156,9 +157,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
     auto add = Add(2, (!name.empty()) ? name + "_add" : "");
 
     // Forget gate
-    auto forgetGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateX" : "");
+    auto forgetGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateX" : "");
     input->addChild(forgetGateX, 0, 0);
-    auto forgetGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateH" : "");
+    auto forgetGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateH" : "");
     hiddenState->addChild(forgetGateH, 1, 0);
     auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : "");
     forgetGateX->addChild(forgetGate, 0, 0);
@@ -171,9 +172,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
     cellState->addChild(forgetGateMul, 1, 1);
 
     // Input gate
-    auto inputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateX" : "");
+    auto inputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateX" : "");
     input->addChild(inputGateX, 0, 0);
-    auto inputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateH" : "");
+    auto inputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateH" : "");
     hiddenState->addChild(inputGateH, 1, 0);
     auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : "");
     inputGateX->addChild(inputGate, 0, 0);
@@ -185,9 +186,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
     inputGateMul->addChild(add, 0, 1);
 
     // Candidate for cell update
-    auto cellCandidateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateX" : "");
+    auto cellCandidateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateX" : "");
     input->addChild(cellCandidateX, 0, 0);
-    auto cellCandidateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateH" : "");
+    auto cellCandidateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateH" : "");
     hiddenState->addChild(cellCandidateH, 1, 0);
     auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : "");
     cellCandidateX->addChild(cellCandidate, 0, 0);
@@ -197,9 +198,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
     cellCandidateAct->addChild(inputGateMul, 0, 1);
 
     // Output gate
-    auto outputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateX" : "");
+    auto outputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateX" : "");
     input->addChild(outputGateX, 0, 0);
-    auto outputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateH" : "");
+    auto outputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateH" : "");
     hiddenState->addChild(outputGateH, 1, 0);
     auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : "");
     outputGateX->addChild(outputGate, 0, 0);
@@ -223,9 +224,33 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
         inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul,
         cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct,
         outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul,
-        cellUpdatedAct});
+        cellUpdatedAct}, false);
 
-    return MetaOperator("LSTM", microGraph, name);
+    microGraph->setOrderedInputs({{input, 0},
+        {inputGateX, 1}, {outputGateX, 1}, {forgetGateX, 1}, {cellCandidateX, 1},
+        {inputGateH, 1}, {outputGateH, 1}, {forgetGateH, 1}, {cellCandidateH, 1},
+        {inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2},
+        {inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2},
+        {hiddenState, 1}, {cellState, 1}});
+
+    auto metaOp = MetaOperator("LSTM", microGraph, name);
+    addProducer(metaOp, 1, {hidden_channels, in_channels}, "wi");
+    addProducer(metaOp, 2, {hidden_channels, in_channels}, "wo");
+    addProducer(metaOp, 3, {hidden_channels, in_channels}, "wf");
+    addProducer(metaOp, 4, {hidden_channels, in_channels}, "wc");
+    addProducer(metaOp, 5, {hidden_channels, hidden_channels}, "ri");
+    addProducer(metaOp, 6, {hidden_channels, hidden_channels}, "ro");
+    addProducer(metaOp, 7, {hidden_channels, hidden_channels}, "rf");
+    addProducer(metaOp, 8, {hidden_channels, hidden_channels}, "rc");
+    addProducer(metaOp, 9, {(noBias ? 0 : hidden_channels)}, "wbi");
+    addProducer(metaOp, 10, {(noBias ? 0 : hidden_channels)}, "wbo");
+    addProducer(metaOp, 11, {(noBias ? 0 : hidden_channels)}, "wbf");
+    addProducer(metaOp, 12, {(noBias ? 0 : hidden_channels)}, "wbc");
+    addProducer(metaOp, 13, {(noBias ? 0 : hidden_channels)}, "rbi");
+    addProducer(metaOp, 14, {(noBias ? 0 : hidden_channels)}, "rbo");
+    addProducer(metaOp, 15, {(noBias ? 0 : hidden_channels)}, "rbf");
+    addProducer(metaOp, 16, {(noBias ? 0 : hidden_channels)}, "rbc");
+    return metaOp;
 }
 }  // namespace Aidge
 
diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp
index 9eb11b6ba..11c3db681 100644
--- a/python_binding/operator/pybind_MetaOperatorDefs.cpp
+++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp
@@ -112,12 +112,14 @@ void declare_LSTMOp(py::module &m) {
   m.def("LSTM", [](DimSize_t in_channels,
                     DimSize_t hidden_channels,
                     DimSize_t seq_length,
+                    bool nobias,
                     const std::string& name)
     {
-        return LSTM(in_channels, hidden_channels, seq_length, name);
+        return LSTM(in_channels, hidden_channels, seq_length, nobias, name);
     }, py::arg("in_channels"),
        py::arg("hidden_channels"),
        py::arg("seq_length"),
+       py::arg("nobias") = false,
        py::arg("name") = "");
 }
 
diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index 5263625d5..328492d43 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -53,13 +53,13 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
     }
 
     SECTION("LSTM") {
-        auto myLSTM = LSTM(32, 64, 16, "ltsm");
+        auto myLSTM = LSTM(32, 64, 16, true, "ltsm");
         auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
 
         auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
         microGraph->save("lstm", false, false);
 
-        REQUIRE(myLSTM->nbInputs() == 3);
+        REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
         REQUIRE(myLSTM->nbData() == 3);
         REQUIRE(myLSTM->nbOutputs() == 2);
 
@@ -69,12 +69,12 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
         myInit->resize({1, 64});
 
         op->associateInput(0, myInput);
-        op->associateInput(1, myInit);
-        op->associateInput(2, myInit);
+        op->associateInput(17, myInit);
+        op->associateInput(18, myInit);
 
         op->computeOutputDims();
+        microGraph->save("lstm_dims", true, true);
         REQUIRE(op->outputDimsForwarded());
-        microGraph->save("lstm_dims", false, false);
 
         //op->updateConsummerProducer();  // require implementation
         //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
-- 
GitLab