diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp index 5bb184b808e0a9d685879e53554ff3be500f5717..9597b533c14b27d282985b13cd8e1199ed5360a8 100644 --- a/include/aidge/operator/MetaOperatorDefs.hpp +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -260,6 +260,17 @@ inline std::shared_ptr<Node> PaddedMaxPooling( return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode); } +/** + * @brief Creates an LSTM (Long Short-Term Memory) operation as a MetaOperator. + * + * This function creates an LSTM operation as a MetaOperator for use in graph-based computation. The optional @p name, when non-empty, is used to derive the names of micro-graph nodes (e.g. the input node is named `<name>_input`); when empty, that node is left unnamed. + * + * @param[in] seq_length The length of the input sequence. + * @return A shared pointer to the MetaOperator_Op representing the LSTM operation. + */ +std::shared_ptr<MetaOperator_Op> LSTM_Op(DimSize_t seq_length, + const std::string &name = ""); + /** * @brief Creates an LSTM (Long Short-Term Memory) operator. * @@ -278,16 +289,6 @@ std::shared_ptr<Node> LSTM(DimSize_t in_channels, bool noBias = false, const std::string &name = ""); -/** - * @brief Creates an LSTM (Long Short-Term Memory) operation as a MetaOperator. - * - * This function creates an LSTM operation as a MetaOperator for use in graph-based computation. - * - * @param[in] seq_length The length of the input sequence. 
- */ -std::shared_ptr<MetaOperator_Op> LSTM_Op(DimSize_t seq_length); - std::shared_ptr<MetaOperator_Op> LeakyOp(); std::shared_ptr<Node> Leaky(const int nbTimeSteps, const float beta, diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index b2811fbaab2b6cd33dc2b105f0044cd8a5edbbc7..35f3d21341fbb529d692a71e597c3b2b76c8426e 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -176,7 +176,8 @@ void declare_LSTMOp(py::module &m) { py::arg("nobias") = false, py::arg("name") = ""); m.def("LSTMOp", &LSTM_Op, - py::arg("seq_length")); + py::arg("seq_length"), + py::arg("name") = ""); } void declare_LeakyOp(py::module &m) { diff --git a/src/operator/MetaOperatorDefs/LSTM.cpp b/src/operator/MetaOperatorDefs/LSTM.cpp index 22c0469b34b52670a910f63604d02f3f8bf6eab7..c7fbe8a16aa727782b4d1b8ecb0b6d8a29c50a86 100644 --- a/src/operator/MetaOperatorDefs/LSTM.cpp +++ b/src/operator/MetaOperatorDefs/LSTM.cpp @@ -23,11 +23,8 @@ namespace Aidge { -std::shared_ptr<Node> LSTM(const DimSize_t inChannel, - const DimSize_t hiddenChannel, - const DimSize_t seqLength, - bool noBias, - const std::string& name) +std::shared_ptr<MetaOperator_Op> LSTM_Op(const DimSize_t seqLength, + const std::string& name) { // Construct micro-graph auto input = Identity((!name.empty()) ? 
name + "_input" : ""); @@ -113,7 +110,18 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel, {hiddenState, 1}, {cellState, 1}}); microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}}); - auto metaOp = MetaOperator("LSTM", microGraph, {}, name); + return std::make_shared<MetaOperator_Op>("LSTM", microGraph); +} + +std::shared_ptr<Node> LSTM(const DimSize_t inChannel, + const DimSize_t hiddenChannel, + const DimSize_t seqLength, + bool noBias, + const std::string& name) +{ + auto op = LSTM_Op(seqLength, name); + auto metaOp = std::make_shared<Node>(op, name); + op->setUpperNode(metaOp); addProducer(metaOp, 1, {hiddenChannel, inChannel}, "wi"); addProducer(metaOp, 2, {hiddenChannel, inChannel}, "wo"); addProducer(metaOp, 3, {hiddenChannel, inChannel}, "wf"); @@ -135,93 +143,4 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel, return metaOp; } -std::shared_ptr<MetaOperator_Op> LSTM_Op(const DimSize_t seqLength) -{ - // Construct micro-graph - auto input = Identity(""); - auto hiddenState = Memorize(seqLength, ""); - auto cellState = Memorize(seqLength, ""); - auto add = Add(""); - - // Forget gate - auto forgetGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - input->addChild(forgetGateX, 0, 0); - auto forgetGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - hiddenState->addChild(forgetGateH, 1, 0); - auto forgetGate = Add(""); - forgetGateX->addChild(forgetGate, 0, 0); - forgetGateH->addChild(forgetGate, 0, 1); - auto forgetGateAct = Sigmoid(""); - auto forgetGateMul = Mul(""); - forgetGate->addChild(forgetGateAct, 0, 0); - forgetGateAct->addChild(forgetGateMul, 0, 0); - forgetGateMul->addChild(add, 0, 0); - cellState->addChild(forgetGateMul, 1, 1); - - // Input gate - auto inputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - input->addChild(inputGateX, 0, 0); - auto inputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - hiddenState->addChild(inputGateH, 1, 0); - auto inputGate = 
Add(""); - inputGateX->addChild(inputGate, 0, 0); - inputGateH->addChild(inputGate, 0, 1); - auto inputGateAct = Sigmoid(""); - auto inputGateMul = Mul(""); - inputGate->addChild(inputGateAct, 0, 0); - inputGateAct->addChild(inputGateMul, 0, 0); - inputGateMul->addChild(add, 0, 1); - - // Candidate for cell update - auto cellCandidateX = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - input->addChild(cellCandidateX, 0, 0); - auto cellCandidateH = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - hiddenState->addChild(cellCandidateH, 1, 0); - auto cellCandidate = Add(""); - cellCandidateX->addChild(cellCandidate, 0, 0); - cellCandidateH->addChild(cellCandidate, 0, 1); - auto cellCandidateAct = Tanh(""); - cellCandidate->addChild(cellCandidateAct, 0, 0); - cellCandidateAct->addChild(inputGateMul, 0, 1); - - // Output gate - auto outputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - input->addChild(outputGateX, 0, 0); - auto outputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), ""); - hiddenState->addChild(outputGateH, 1, 0); - auto outputGate = Add(""); - outputGateX->addChild(outputGate, 0, 0); - outputGateH->addChild(outputGate, 0, 1); - auto outputGateAct = Sigmoid(""); - auto outputGateMul = Mul(""); - outputGate->addChild(outputGateAct, 0, 0); - outputGateAct->addChild(outputGateMul, 0, 0); - - // Updated cell state to help determine new hidden state - auto cellUpdatedAct = Tanh(""); - add->addChild(cellUpdatedAct, 0, 0); - cellUpdatedAct->addChild(outputGateMul, 0, 1); - outputGateMul->addChild(hiddenState, 0, 0); - add->addChild(cellState, 0, 0); - - std::shared_ptr<GraphView> microGraph = std::make_shared<GraphView>(); - microGraph->add(input); - microGraph->add({hiddenState, cellState, add, - forgetGateX, forgetGateH, forgetGate, forgetGateAct, forgetGateMul, - inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul, - cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct, - outputGateX, outputGateH, 
outputGate, outputGateAct, outputGateMul, - cellUpdatedAct}, false); - - microGraph->setOrderedInputs({{input, 0}, - {inputGateX, 1}, {outputGateX, 1}, {forgetGateX, 1}, {cellCandidateX, 1}, - {inputGateH, 1}, {outputGateH, 1}, {forgetGateH, 1}, {cellCandidateH, 1}, - {inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2}, - {inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2}, - {hiddenState, 1}, {cellState, 1}}); - microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}}); - - return std::make_shared<MetaOperator_Op>("LSTM", microGraph); -} - } // namespace Aidge