Skip to content
Snippets Groups Projects
Commit 4e75c3ea authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Removed code redundancy

parent 53e848ad
No related branches found
No related tags found
2 merge requests!329[Fix] Log no wrap for Path,!325[Upd] Patch v0.5.1
......@@ -260,6 +260,17 @@ inline std::shared_ptr<Node> PaddedMaxPooling(
return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims, ceil_mode);
}
/**
* @brief Creates an LSTM (Long Short-Term Memory) operation as a MetaOperator.
*
* This function creates an LSTM operation as a MetaOperator for use in graph-based computation.
*
* @param[in] seq_length The length of the input sequence.
* @return A shared pointer to the MetaOperator_Op representing the LSTM operation.
*/
std::shared_ptr<MetaOperator_Op> LSTM_Op(DimSize_t seq_length,
const std::string &name = "");
/**
* @brief Creates an LSTM (Long Short-Term Memory) operator.
*
......@@ -278,16 +289,6 @@ std::shared_ptr<Node> LSTM(DimSize_t in_channels,
bool noBias = false,
const std::string &name = "");
/**
* @brief Creates an LSTM (Long Short-Term Memory) operation as a MetaOperator.
*
* This function creates an LSTM operation as a MetaOperator for use in graph-based computation.
*
* @param[in] seq_length The length of the input sequence.
* @return A shared pointer to the MetaOperator_Op representing the LSTM operation.
*/
std::shared_ptr<MetaOperator_Op> LSTM_Op(DimSize_t seq_length);
std::shared_ptr<MetaOperator_Op> LeakyOp();
std::shared_ptr<Node> Leaky(const int nbTimeSteps,
const float beta,
......
......@@ -176,7 +176,8 @@ void declare_LSTMOp(py::module &m) {
py::arg("nobias") = false,
py::arg("name") = "");
m.def("LSTMOp", &LSTM_Op,
py::arg("seq_length"));
py::arg("seq_length"),
py::arg("name") = "");
}
void declare_LeakyOp(py::module &m) {
......
......@@ -23,11 +23,8 @@
namespace Aidge {
std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
const DimSize_t hiddenChannel,
const DimSize_t seqLength,
bool noBias,
const std::string& name)
std::shared_ptr<MetaOperator_Op> LSTM_Op(const DimSize_t seqLength,
const std::string& name)
{
// Construct micro-graph
auto input = Identity((!name.empty()) ? name + "_input" : "");
......@@ -113,7 +110,18 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
{hiddenState, 1}, {cellState, 1}});
microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
auto metaOp = MetaOperator("LSTM", microGraph, {}, name);
return std::make_shared<MetaOperator_Op>("LSTM", microGraph);
}
std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
const DimSize_t hiddenChannel,
const DimSize_t seqLength,
bool noBias,
const std::string& name)
{
auto op = LSTM_Op(seqLength, name);
auto metaOp = std::make_shared<Node>(op, name);
op->setUpperNode(metaOp);
addProducer(metaOp, 1, {hiddenChannel, inChannel}, "wi");
addProducer(metaOp, 2, {hiddenChannel, inChannel}, "wo");
addProducer(metaOp, 3, {hiddenChannel, inChannel}, "wf");
......@@ -135,93 +143,4 @@ std::shared_ptr<Node> LSTM(const DimSize_t inChannel,
return metaOp;
}
std::shared_ptr<MetaOperator_Op> LSTM_Op(const DimSize_t seqLength)
{
// Construct micro-graph
auto input = Identity("");
auto hiddenState = Memorize(seqLength, "");
auto cellState = Memorize(seqLength, "");
auto add = Add("");
// Forget gate
auto forgetGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(forgetGateX, 0, 0);
auto forgetGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(forgetGateH, 1, 0);
auto forgetGate = Add("");
forgetGateX->addChild(forgetGate, 0, 0);
forgetGateH->addChild(forgetGate, 0, 1);
auto forgetGateAct = Sigmoid("");
auto forgetGateMul = Mul("");
forgetGate->addChild(forgetGateAct, 0, 0);
forgetGateAct->addChild(forgetGateMul, 0, 0);
forgetGateMul->addChild(add, 0, 0);
cellState->addChild(forgetGateMul, 1, 1);
// Input gate
auto inputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(inputGateX, 0, 0);
auto inputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(inputGateH, 1, 0);
auto inputGate = Add("");
inputGateX->addChild(inputGate, 0, 0);
inputGateH->addChild(inputGate, 0, 1);
auto inputGateAct = Sigmoid("");
auto inputGateMul = Mul("");
inputGate->addChild(inputGateAct, 0, 0);
inputGateAct->addChild(inputGateMul, 0, 0);
inputGateMul->addChild(add, 0, 1);
// Candidate for cell update
auto cellCandidateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(cellCandidateX, 0, 0);
auto cellCandidateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(cellCandidateH, 1, 0);
auto cellCandidate = Add("");
cellCandidateX->addChild(cellCandidate, 0, 0);
cellCandidateH->addChild(cellCandidate, 0, 1);
auto cellCandidateAct = Tanh("");
cellCandidate->addChild(cellCandidateAct, 0, 0);
cellCandidateAct->addChild(inputGateMul, 0, 1);
// Output gate
auto outputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
input->addChild(outputGateX, 0, 0);
auto outputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(), "");
hiddenState->addChild(outputGateH, 1, 0);
auto outputGate = Add("");
outputGateX->addChild(outputGate, 0, 0);
outputGateH->addChild(outputGate, 0, 1);
auto outputGateAct = Sigmoid("");
auto outputGateMul = Mul("");
outputGate->addChild(outputGateAct, 0, 0);
outputGateAct->addChild(outputGateMul, 0, 0);
// Updated cell state to help determine new hidden state
auto cellUpdatedAct = Tanh("");
add->addChild(cellUpdatedAct, 0, 0);
cellUpdatedAct->addChild(outputGateMul, 0, 1);
outputGateMul->addChild(hiddenState, 0, 0);
add->addChild(cellState, 0, 0);
std::shared_ptr<GraphView> microGraph = std::make_shared<GraphView>();
microGraph->add(input);
microGraph->add({hiddenState, cellState, add,
forgetGateX, forgetGateH, forgetGate, forgetGateAct, forgetGateMul,
inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul,
cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct,
outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul,
cellUpdatedAct}, false);
microGraph->setOrderedInputs({{input, 0},
{inputGateX, 1}, {outputGateX, 1}, {forgetGateX, 1}, {cellCandidateX, 1},
{inputGateH, 1}, {outputGateH, 1}, {forgetGateH, 1}, {cellCandidateH, 1},
{inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2},
{inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2},
{hiddenState, 1}, {cellState, 1}});
microGraph->setOrderedOutputs({{hiddenState, 0}, {cellState, 0}});
return std::make_shared<MetaOperator_Op>("LSTM", microGraph);
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment