Commit 1cdefe55 authored by Olivier BICHLER

Fixed Producers for LSTM

parent 47aa7d13
2 merge requests: !105 version 0.2.0, !77 Support for recurrent networks
Pipeline #38999 failed
@@ -140,8 +140,8 @@ public:
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
- inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; };
- inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; };
+ inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const { return mInputNodes; };
+ inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const { return mOutputNodes; };
/**
* @brief List outside data input connections of the GraphView.
......
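Note on the GraphView hunk above: getOrderedInputs() and getOrderedOutputs() are now const member functions, so they can be called on a GraphView held by const reference. A minimal sketch, assuming the usual aidge GraphView header; printOrderedIO is a hypothetical helper, not part of this commit:

    #include <cstdio>
    #include "aidge/graph/GraphView.hpp"

    // Compiles only because both getters are now const-qualified.
    void printOrderedIO(const Aidge::GraphView& gv) {
        std::printf("ordered inputs: %zu, ordered outputs: %zu\n",
                    gv.getOrderedInputs().size(),
                    gv.getOrderedOutputs().size());
    }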
@@ -147,6 +147,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(
inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
DimSize_t hidden_channels,
DimSize_t seq_length,
+ bool noBias = false,
const std::string& name = "")
{
// Construct micro-graph
@@ -156,9 +157,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
auto add = Add(2, (!name.empty()) ? name + "_add" : "");
// Forget gate
- auto forgetGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateX" : "");
+ auto forgetGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateX" : "");
input->addChild(forgetGateX, 0, 0);
- auto forgetGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateH" : "");
+ auto forgetGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateH" : "");
hiddenState->addChild(forgetGateH, 1, 0);
auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : "");
forgetGateX->addChild(forgetGate, 0, 0);
@@ -171,9 +172,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
cellState->addChild(forgetGateMul, 1, 1);
// Input gate
- auto inputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateX" : "");
+ auto inputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateX" : "");
input->addChild(inputGateX, 0, 0);
- auto inputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateH" : "");
+ auto inputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateH" : "");
hiddenState->addChild(inputGateH, 1, 0);
auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : "");
inputGateX->addChild(inputGate, 0, 0);
@@ -185,9 +186,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
inputGateMul->addChild(add, 0, 1);
// Candidate for cell update
- auto cellCandidateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateX" : "");
+ auto cellCandidateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateX" : "");
input->addChild(cellCandidateX, 0, 0);
- auto cellCandidateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateH" : "");
+ auto cellCandidateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateH" : "");
hiddenState->addChild(cellCandidateH, 1, 0);
auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : "");
cellCandidateX->addChild(cellCandidate, 0, 0);
@@ -197,9 +198,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
cellCandidateAct->addChild(inputGateMul, 0, 1);
// Output gate
- auto outputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateX" : "");
+ auto outputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateX" : "");
input->addChild(outputGateX, 0, 0);
- auto outputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateH" : "");
+ auto outputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateH" : "");
hiddenState->addChild(outputGateH, 1, 0);
auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : "");
outputGateX->addChild(outputGate, 0, 0);
@@ -223,9 +224,33 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul,
cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct,
outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul,
- cellUpdatedAct});
+ cellUpdatedAct}, false);
- return MetaOperator("LSTM", microGraph, name);
+ microGraph->setOrderedInputs({{input, 0},
+ {inputGateX, 1}, {outputGateX, 1}, {forgetGateX, 1}, {cellCandidateX, 1},
+ {inputGateH, 1}, {outputGateH, 1}, {forgetGateH, 1}, {cellCandidateH, 1},
+ {inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2},
+ {inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2},
+ {hiddenState, 1}, {cellState, 1}});
+ auto metaOp = MetaOperator("LSTM", microGraph, name);
+ addProducer(metaOp, 1, {hidden_channels, in_channels}, "wi");
+ addProducer(metaOp, 2, {hidden_channels, in_channels}, "wo");
+ addProducer(metaOp, 3, {hidden_channels, in_channels}, "wf");
+ addProducer(metaOp, 4, {hidden_channels, in_channels}, "wc");
+ addProducer(metaOp, 5, {hidden_channels, hidden_channels}, "ri");
+ addProducer(metaOp, 6, {hidden_channels, hidden_channels}, "ro");
+ addProducer(metaOp, 7, {hidden_channels, hidden_channels}, "rf");
+ addProducer(metaOp, 8, {hidden_channels, hidden_channels}, "rc");
+ addProducer(metaOp, 9, {(noBias ? 0 : hidden_channels)}, "wbi");
+ addProducer(metaOp, 10, {(noBias ? 0 : hidden_channels)}, "wbo");
+ addProducer(metaOp, 11, {(noBias ? 0 : hidden_channels)}, "wbf");
+ addProducer(metaOp, 12, {(noBias ? 0 : hidden_channels)}, "wbc");
+ addProducer(metaOp, 13, {(noBias ? 0 : hidden_channels)}, "rbi");
+ addProducer(metaOp, 14, {(noBias ? 0 : hidden_channels)}, "rbo");
+ addProducer(metaOp, 15, {(noBias ? 0 : hidden_channels)}, "rbf");
+ addProducer(metaOp, 16, {(noBias ? 0 : hidden_channels)}, "rbc");
+ return metaOp;
}
} // namespace Aidge
......
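For reference, a minimal usage sketch of the updated LSTM factory (mirroring the unit test further down; inputSeq and initialState are assumed, already-filled tensors): input 0 is the data sequence, inputs 1-8 are the weight producers and 9-16 the bias producers added above, and inputs 17 and 18 are the initial hidden and cell states, following the setOrderedInputs order.

    auto lstm = Aidge::LSTM(/*in_channels=*/32, /*hidden_channels=*/64,
                            /*seq_length=*/16, /*noBias=*/false, "lstm");
    auto op = std::static_pointer_cast<Aidge::OperatorTensor>(lstm->getOperator());
    op->associateInput(0,  inputSeq);      // data sequence
    op->associateInput(17, initialState);  // initial hidden state
    op->associateInput(18, initialState);  // initial cell state
    op->computeOutputDims();               // forward dimensions through the micro-graph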
@@ -112,12 +112,14 @@ void declare_LSTMOp(py::module &m) {
m.def("LSTM", [](DimSize_t in_channels,
DimSize_t hidden_channels,
DimSize_t seq_length,
+ bool nobias,
const std::string& name)
{
- return LSTM(in_channels, hidden_channels, seq_length, name);
+ return LSTM(in_channels, hidden_channels, seq_length, nobias, name);
}, py::arg("in_channels"),
py::arg("hidden_channels"),
py::arg("seq_length"),
py::arg("nobias") = false,
py::arg("name") = "");
}
......
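On the Python side, nobias defaults to False (matching the new C++ default, i.e. biases enabled); assuming the usual aidge_core module name, a call now looks like aidge_core.LSTM(32, 64, 16, nobias=False, name="lstm").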
@@ -53,13 +53,13 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
}
SECTION("LSTM") {
- auto myLSTM = LSTM(32, 64, 16, "ltsm");
+ auto myLSTM = LSTM(32, 64, 16, true, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false);
- REQUIRE(myLSTM->nbInputs() == 3);
+ REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 3);
REQUIRE(myLSTM->nbOutputs() == 2);
@@ -69,12 +69,12 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
myInit->resize({1, 64});
op->associateInput(0, myInput);
- op->associateInput(1, myInit);
- op->associateInput(2, myInit);
+ op->associateInput(17, myInit);
+ op->associateInput(18, myInit);
op->computeOutputDims();
microGraph->save("lstm_dims", true, true);
REQUIRE(op->outputDimsForwarded());
microGraph->save("lstm_dims", false, false);
//op->updateConsummerProducer(); // require implementation
//auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
......