From 1cdefe5502a003c958a865eeac108aa1096d1051 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 13 Feb 2024 17:21:51 +0100 Subject: [PATCH] Fixed Producers for LSTM --- include/aidge/graph/GraphView.hpp | 4 +- include/aidge/operator/MetaOperatorDefs.hpp | 45 ++++++++++++++----- .../operator/pybind_MetaOperatorDefs.cpp | 4 +- unit_tests/operator/Test_MetaOperator.cpp | 10 ++--- 4 files changed, 45 insertions(+), 18 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 77a759d9b..d3b022463 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -140,8 +140,8 @@ public: void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); - inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; }; - inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; }; + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const { return mInputNodes; }; + inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const { return mOutputNodes; }; /** * @brief List outside data input connections of the GraphView. diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index 2c7044f55..d51844e3b 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -147,6 +147,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling( inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, DimSize_t hidden_channels, DimSize_t seq_length, + bool noBias = false, const std::string& name = "") { // Construct micro-graph @@ -156,9 +157,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, auto add = Add(2, (!name.empty()) ? name + "_add" : ""); // Forget gate - auto forgetGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateX" : ""); + auto forgetGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateX" : ""); input->addChild(forgetGateX, 0, 0); - auto forgetGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateH" : ""); + auto forgetGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateH" : ""); hiddenState->addChild(forgetGateH, 1, 0); auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : ""); forgetGateX->addChild(forgetGate, 0, 0); @@ -171,9 +172,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, cellState->addChild(forgetGateMul, 1, 1); // Input gate - auto inputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateX" : ""); + auto inputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateX" : ""); input->addChild(inputGateX, 0, 0); - auto inputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateH" : ""); + auto inputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateH" : ""); hiddenState->addChild(inputGateH, 1, 0); auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : ""); inputGateX->addChild(inputGate, 0, 0); @@ -185,9 +186,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, inputGateMul->addChild(add, 0, 1); // Candidate for cell update - auto cellCandidateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateX" : ""); + auto cellCandidateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateX" : ""); input->addChild(cellCandidateX, 0, 0); - auto cellCandidateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateH" : ""); + auto cellCandidateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateH" : ""); hiddenState->addChild(cellCandidateH, 1, 0); auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : ""); cellCandidateX->addChild(cellCandidate, 0, 0); @@ -197,9 +198,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, cellCandidateAct->addChild(inputGateMul, 0, 1); // Output gate - auto outputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateX" : ""); + auto outputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateX" : ""); input->addChild(outputGateX, 0, 0); - auto outputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateH" : ""); + auto outputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateH" : ""); hiddenState->addChild(outputGateH, 1, 0); auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : ""); outputGateX->addChild(outputGate, 0, 0); @@ -223,9 +224,33 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul, cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct, outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul, - cellUpdatedAct}); + cellUpdatedAct}, false); - return MetaOperator("LSTM", microGraph, name); + microGraph->setOrderedInputs({{input, 0}, + {inputGateX, 1}, {outputGateX, 1}, {forgetGateX, 1}, {cellCandidateX, 1}, + {inputGateH, 1}, {outputGateH, 1}, {forgetGateH, 1}, {cellCandidateH, 1}, + {inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2}, + {inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2}, + {hiddenState, 1}, {cellState, 1}}); + + auto metaOp = MetaOperator("LSTM", microGraph, name); + addProducer(metaOp, 1, {hidden_channels, in_channels}, "wi"); + addProducer(metaOp, 2, {hidden_channels, in_channels}, "wo"); + addProducer(metaOp, 3, {hidden_channels, in_channels}, "wf"); + addProducer(metaOp, 4, {hidden_channels, in_channels}, "wc"); + addProducer(metaOp, 5, {hidden_channels, hidden_channels}, "ri"); + addProducer(metaOp, 6, {hidden_channels, hidden_channels}, "ro"); + addProducer(metaOp, 7, {hidden_channels, hidden_channels}, "rf"); + addProducer(metaOp, 8, {hidden_channels, hidden_channels}, "rc"); + addProducer(metaOp, 9, {(noBias ? 0 : hidden_channels)}, "wbi"); + addProducer(metaOp, 10, {(noBias ? 0 : hidden_channels)}, "wbo"); + addProducer(metaOp, 11, {(noBias ? 0 : hidden_channels)}, "wbf"); + addProducer(metaOp, 12, {(noBias ? 0 : hidden_channels)}, "wbc"); + addProducer(metaOp, 13, {(noBias ? 0 : hidden_channels)}, "rbi"); + addProducer(metaOp, 14, {(noBias ? 0 : hidden_channels)}, "rbo"); + addProducer(metaOp, 15, {(noBias ? 0 : hidden_channels)}, "rbf"); + addProducer(metaOp, 16, {(noBias ? 0 : hidden_channels)}, "rbc"); + return metaOp; } } // namespace Aidge diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index 9eb11b6ba..11c3db681 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -112,12 +112,14 @@ void declare_LSTMOp(py::module &m) { m.def("LSTM", [](DimSize_t in_channels, DimSize_t hidden_channels, DimSize_t seq_length, + bool nobias, const std::string& name) { - return LSTM(in_channels, hidden_channels, seq_length, name); + return LSTM(in_channels, hidden_channels, seq_length, nobias, name); }, py::arg("in_channels"), py::arg("hidden_channels"), py::arg("seq_length"), + py::arg("nobias") = false, py::arg("name") = ""); } diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 5263625d5..328492d43 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -53,13 +53,13 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { } SECTION("LSTM") { - auto myLSTM = LSTM(32, 64, 16, "ltsm"); + auto myLSTM = LSTM(32, 64, 16, true, "ltsm"); auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph(); microGraph->save("lstm", false, false); - REQUIRE(myLSTM->nbInputs() == 3); + REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); REQUIRE(myLSTM->nbData() == 3); REQUIRE(myLSTM->nbOutputs() == 2); @@ -69,12 +69,12 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { myInit->resize({1, 64}); op->associateInput(0, myInput); - op->associateInput(1, myInit); - op->associateInput(2, myInit); + op->associateInput(17, myInit); + op->associateInput(18, myInit); op->computeOutputDims(); + microGraph->save("lstm_dims", true, true); REQUIRE(op->outputDimsForwarded()); - microGraph->save("lstm_dims", false, false); //op->updateConsummerProducer(); // require implementation //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); -- GitLab