Skip to content
Snippets Groups Projects
Commit 1cdefe55 authored by Olivier BICHLER
Browse files

Fixed Producers for LSTM

parent 47aa7d13
No related branches found
No related tags found
2 merge requests: !105 "version 0.2.0", !77 "Support for recurrent networks"
Pipeline #38999 failed
...@@ -140,8 +140,8 @@ public: ...@@ -140,8 +140,8 @@ public:
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; }; inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const { return mInputNodes; };
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; }; inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const { return mOutputNodes; };
/** /**
* @brief List outside data input connections of the GraphView. * @brief List outside data input connections of the GraphView.
......
...@@ -147,6 +147,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling( ...@@ -147,6 +147,7 @@ inline std::shared_ptr<Node> PaddedMaxPooling(
inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
DimSize_t hidden_channels, DimSize_t hidden_channels,
DimSize_t seq_length, DimSize_t seq_length,
bool noBias = false,
const std::string& name = "") const std::string& name = "")
{ {
// Construct micro-graph // Construct micro-graph
...@@ -156,9 +157,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, ...@@ -156,9 +157,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
auto add = Add(2, (!name.empty()) ? name + "_add" : ""); auto add = Add(2, (!name.empty()) ? name + "_add" : "");
// Forget gate // Forget gate
auto forgetGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateX" : ""); auto forgetGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateX" : "");
input->addChild(forgetGateX, 0, 0); input->addChild(forgetGateX, 0, 0);
auto forgetGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_forgetGateH" : ""); auto forgetGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_forgetGateH" : "");
hiddenState->addChild(forgetGateH, 1, 0); hiddenState->addChild(forgetGateH, 1, 0);
auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : ""); auto forgetGate = Add(2, (!name.empty()) ? name + "_forgetGate" : "");
forgetGateX->addChild(forgetGate, 0, 0); forgetGateX->addChild(forgetGate, 0, 0);
...@@ -171,9 +172,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, ...@@ -171,9 +172,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
cellState->addChild(forgetGateMul, 1, 1); cellState->addChild(forgetGateMul, 1, 1);
// Input gate // Input gate
auto inputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateX" : ""); auto inputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateX" : "");
input->addChild(inputGateX, 0, 0); input->addChild(inputGateX, 0, 0);
auto inputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_inputGateH" : ""); auto inputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_inputGateH" : "");
hiddenState->addChild(inputGateH, 1, 0); hiddenState->addChild(inputGateH, 1, 0);
auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : ""); auto inputGate = Add(2, (!name.empty()) ? name + "_inputGate" : "");
inputGateX->addChild(inputGate, 0, 0); inputGateX->addChild(inputGate, 0, 0);
...@@ -185,9 +186,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, ...@@ -185,9 +186,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
inputGateMul->addChild(add, 0, 1); inputGateMul->addChild(add, 0, 1);
// Candidate for cell update // Candidate for cell update
auto cellCandidateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateX" : ""); auto cellCandidateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateX" : "");
input->addChild(cellCandidateX, 0, 0); input->addChild(cellCandidateX, 0, 0);
auto cellCandidateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_cellCandidateH" : ""); auto cellCandidateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_cellCandidateH" : "");
hiddenState->addChild(cellCandidateH, 1, 0); hiddenState->addChild(cellCandidateH, 1, 0);
auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : ""); auto cellCandidate = Add(2, (!name.empty()) ? name + "_cellCandidate" : "");
cellCandidateX->addChild(cellCandidate, 0, 0); cellCandidateX->addChild(cellCandidate, 0, 0);
...@@ -197,9 +198,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, ...@@ -197,9 +198,9 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
cellCandidateAct->addChild(inputGateMul, 0, 1); cellCandidateAct->addChild(inputGateMul, 0, 1);
// Output gate // Output gate
auto outputGateX = FC(in_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateX" : ""); auto outputGateX = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateX" : "");
input->addChild(outputGateX, 0, 0); input->addChild(outputGateX, 0, 0);
auto outputGateH = FC(hidden_channels, hidden_channels, false, (!name.empty()) ? name + "_outputGateH" : ""); auto outputGateH = std::make_shared<Node>(std::make_shared<FC_Op>(hidden_channels, noBias), (!name.empty()) ? name + "_outputGateH" : "");
hiddenState->addChild(outputGateH, 1, 0); hiddenState->addChild(outputGateH, 1, 0);
auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : ""); auto outputGate = Add(2, (!name.empty()) ? name + "_outputGate" : "");
outputGateX->addChild(outputGate, 0, 0); outputGateX->addChild(outputGate, 0, 0);
...@@ -223,9 +224,33 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels, ...@@ -223,9 +224,33 @@ inline std::shared_ptr<Node> LSTM(DimSize_t in_channels,
inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul, inputGateX, inputGateH, inputGate, inputGateAct, inputGateMul,
cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct, cellCandidateX, cellCandidateH, cellCandidate, cellCandidateAct,
outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul, outputGateX, outputGateH, outputGate, outputGateAct, outputGateMul,
cellUpdatedAct}); cellUpdatedAct}, false);
return MetaOperator("LSTM", microGraph, name); microGraph->setOrderedInputs({{input, 0},
{inputGateX, 1}, {outputGateX, 1}, {forgetGateX, 1}, {cellCandidateX, 1},
{inputGateH, 1}, {outputGateH, 1}, {forgetGateH, 1}, {cellCandidateH, 1},
{inputGateX, 2}, {outputGateX, 2}, {forgetGateX, 2}, {cellCandidateX, 2},
{inputGateH, 2}, {outputGateH, 2}, {forgetGateH, 2}, {cellCandidateH, 2},
{hiddenState, 1}, {cellState, 1}});
auto metaOp = MetaOperator("LSTM", microGraph, name);
addProducer(metaOp, 1, {hidden_channels, in_channels}, "wi");
addProducer(metaOp, 2, {hidden_channels, in_channels}, "wo");
addProducer(metaOp, 3, {hidden_channels, in_channels}, "wf");
addProducer(metaOp, 4, {hidden_channels, in_channels}, "wc");
addProducer(metaOp, 5, {hidden_channels, hidden_channels}, "ri");
addProducer(metaOp, 6, {hidden_channels, hidden_channels}, "ro");
addProducer(metaOp, 7, {hidden_channels, hidden_channels}, "rf");
addProducer(metaOp, 8, {hidden_channels, hidden_channels}, "rc");
addProducer(metaOp, 9, {(noBias ? 0 : hidden_channels)}, "wbi");
addProducer(metaOp, 10, {(noBias ? 0 : hidden_channels)}, "wbo");
addProducer(metaOp, 11, {(noBias ? 0 : hidden_channels)}, "wbf");
addProducer(metaOp, 12, {(noBias ? 0 : hidden_channels)}, "wbc");
addProducer(metaOp, 13, {(noBias ? 0 : hidden_channels)}, "rbi");
addProducer(metaOp, 14, {(noBias ? 0 : hidden_channels)}, "rbo");
addProducer(metaOp, 15, {(noBias ? 0 : hidden_channels)}, "rbf");
addProducer(metaOp, 16, {(noBias ? 0 : hidden_channels)}, "rbc");
return metaOp;
} }
} // namespace Aidge } // namespace Aidge
......
...@@ -112,12 +112,14 @@ void declare_LSTMOp(py::module &m) { ...@@ -112,12 +112,14 @@ void declare_LSTMOp(py::module &m) {
m.def("LSTM", [](DimSize_t in_channels, m.def("LSTM", [](DimSize_t in_channels,
DimSize_t hidden_channels, DimSize_t hidden_channels,
DimSize_t seq_length, DimSize_t seq_length,
bool nobias,
const std::string& name) const std::string& name)
{ {
return LSTM(in_channels, hidden_channels, seq_length, name); return LSTM(in_channels, hidden_channels, seq_length, nobias, name);
}, py::arg("in_channels"), }, py::arg("in_channels"),
py::arg("hidden_channels"), py::arg("hidden_channels"),
py::arg("seq_length"), py::arg("seq_length"),
py::arg("nobias") = false,
py::arg("name") = ""); py::arg("name") = "");
} }
......
...@@ -53,13 +53,13 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { ...@@ -53,13 +53,13 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
} }
SECTION("LSTM") { SECTION("LSTM") {
auto myLSTM = LSTM(32, 64, 16, "ltsm"); auto myLSTM = LSTM(32, 64, 16, true, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph(); auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false); microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 3); REQUIRE(myLSTM->nbData() == 3);
REQUIRE(myLSTM->nbOutputs() == 2); REQUIRE(myLSTM->nbOutputs() == 2);
...@@ -69,12 +69,12 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { ...@@ -69,12 +69,12 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") {
myInit->resize({1, 64}); myInit->resize({1, 64});
op->associateInput(0, myInput); op->associateInput(0, myInput);
op->associateInput(1, myInit); op->associateInput(17, myInit);
op->associateInput(2, myInit); op->associateInput(18, myInit);
op->computeOutputDims(); op->computeOutputDims();
microGraph->save("lstm_dims", true, true);
REQUIRE(op->outputDimsForwarded()); REQUIRE(op->outputDimsForwarded());
microGraph->save("lstm_dims", false, false);
//op->updateConsummerProducer(); // require implementation //op->updateConsummerProducer(); // require implementation
//auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment