Skip to content
Snippets Groups Projects
Commit 5efc366b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'remove_lstm_redundancy' into 'dev'

Removed code redundancy

See merge request !324
parents 7142cbf9 69720354
No related branches found
No related tags found
2 merge requests!333[Fix] Attribute snake case,!324Removed code redundancy
Pipeline #64828 passed
......@@ -260,6 +260,17 @@ inline std::shared_ptr<Node> PaddedMaxPooling(
return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode);
}
/**
* @brief Creates an LSTM (Long Short-Term Memory) operation as a MetaOperator.
*
* This function creates an LSTM operation as a MetaOperator for use in graph-based computation.
*
* @param[in] seq_length The length of the input sequence.
* @return A shared pointer to the MetaOperator_Op representing the LSTM operation.
*/
std::shared_ptr<MetaOperator_Op> LSTM_Op(DimSize_t seq_length,
const std::string &name = "");
/**
* @brief Creates an LSTM (Long Short-Term Memory) operator.
*
......@@ -278,16 +289,6 @@ std::shared_ptr<Node> LSTM(DimSize_t in_channels,
bool noBias = false,
const std::string &name = "");
/**
* @brief Creates an LSTM (Long Short-Term Memory) operation as a MetaOperator.
*
* This function creates an LSTM operation as a MetaOperator for use in graph-based computation.
*
* @param[in] seq_length The length of the input sequence.
* @return A shared pointer to the MetaOperator_Op representing the LSTM operation.
*/
std::shared_ptr<MetaOperator_Op> LSTM_Op(DimSize_t seq_length);
std::shared_ptr<MetaOperator_Op> LeakyOp();
std::shared_ptr<Node> Leaky(const int nbTimeSteps,
const float beta,
......
......@@ -176,7 +176,8 @@ void declare_LSTMOp(py::module &m) {
py::arg("nobias") = false,
py::arg("name") = "");
m.def("LSTMOp", &LSTM_Op,
py::arg("seq_length"));
py::arg("seq_length"),
py::arg("name") = "");
}
void declare_LeakyOp(py::module &m) {
......
......@@ -23,11 +23,8 @@
namespace Aidge {
std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
const DimSize_t hiddenChannel,
const DimSize_t seqLength,
bool noBias,
const std::string& name)
std::shared_ptr<MetaOperator_Op> LSTM_Op(const DimSize_t seqLength,
const std::string& name)
{
// Construct micro-graph
auto input = Identity((!name.empty()) ? name + "_input" : "");
......@@ -113,7 +110,18 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
{hiddenState, 1}, {cellState, 1}});
microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
auto metaOp = MetaOperator("LSTM", microGraph, {}, name);
return std::make_shared<MetaOperator_Op>("LSTM", microGraph);
}
std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
const DimSize_t hiddenChannel,
const DimSize_t seqLength,
bool noBias,
const std::string& name)
{
auto op = LSTM_Op(seqLength, name);
auto metaOp = std::make_shared<Node>(op, name);
op->setUpperNode(metaOp);
addProducer(metaOp, 1, {hiddenChannel, inChannel}, "wi");
addProducer(metaOp, 2, {hiddenChannel, inChannel}, "wo");
addProducer(metaOp, 3, {hiddenChannel, inChannel}, "wf");
......@@ -135,93 +143,4 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
return metaOp;
}
std::shared_ptr<MetaOperator_Op> LSTM_Op(const DimSize_t seqLength)
{
// Construct micro-graph
auto input = Identity("");
auto hiddenState = Memorize(seqLength, "");
auto cellState = Memorize(seqLength, "");
auto add = Add("");
// Forget gate
auto forgetGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(forgetGateX, 0, 0);
auto forgetGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(forgetGateH, 1, 0);
auto forgetGate = Add("");
forgetGateX->addChild(forgetGate, 0, 0);
forgetGateH->addChild(forgetGate, 0, 1);
auto forgetGateAct = Sigmoid("");
auto forgetGateMul = Mul("");
forgetGate->addChild(forgetGateAct, 0, 0);
forgetGateAct->addChild(forgetGateMul, 0, 0);
forgetGateMul->addChild(add, 0, 0);
cellState->addChild(forgetGateMul, 1, 1);
// Input gate
auto inputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(inputGateX, 0, 0);
auto inputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(inputGateH, 1, 0);
auto inputGate = Add("");
inputGateX->addChild(inputGate, 0, 0);
inputGateH->addChild(inputGate, 0, 1);
auto inputGateAct = Sigmoid("");
auto inputGateMul = Mul("");
inputGate->addChild(inputGateAct, 0, 0);
inputGateAct->addChild(inputGateMul, 0, 0);
inputGateMul->addChild(add, 0, 1);
// Candidate for cell update
auto cellCandidateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(cellCandidateX, 0, 0);
auto cellCandidateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(cellCandidateH, 1, 0);
auto cellCandidate = Add("");
cellCandidateX->addChild(cellCandidate, 0, 0);
cellCandidateH->addChild(cellCandidate, 0, 1);
auto cellCandidateAct = Tanh("");
cellCandidate->addChild(cellCandidateAct, 0, 0);
cellCandidateAct->addChild(inputGateMul, 0, 1);
// Output gate
auto outputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(outputGateX, 0, 0);
auto outputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(outputGateH, 1, 0);
auto outputGate = Add("");
outputGateX->addChild(outputGate, 0, 0);
outputGateH->addChild(outputGate, 0, 1);
auto outputGateAct = Sigmoid("");
auto outputGateMul = Mul("");
outputGate->addChild(outputGateAct, 0, 0);
outputGateAct->addChild(outputGateMul, 0, 0);
// Updated cell state to help determine new hidden state
auto cellUpdatedAct = Tanh("");
add->addChild(cellUpdatedAct, 0, 0);
cellUpdatedAct->addChild(outputGateMul, 0, 1);
outputGateMul->addChild(hiddenState, 0, 0);
add->addChild(cellState, 0, 0);
std::shared_ptr<GraphView> microGraph = std::make_shared<GraphView>();
microGraph->add(input);
microGraph->add({hiddenState, cellState, add,
forgetGateX, forgetGateH, forgetGate, forgetGateAct, forgetGateMul,
inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul,
cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct,
outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul,
cellUpdatedAct}, false);
microGraph->setOrderedInputs({{input, 0},
{inputGateX, 1}, {outputGateX, 1}, {forgetGateX, 1}, {cellCandidateX, 1},
{inputGateH, 1}, {outputGateH, 1}, {forgetGateH, 1}, {cellCandidateH, 1},
{inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2},
{inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2},
{hiddenState, 1}, {cellState, 1}});
microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
return std::make_shared<MetaOperator_Op>("LSTM", microGraph);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment