Commit e5cc0ce2 authored by Olivier BICHLER

First working version of LSTM inference

parent 0149e9c1
2 merge requests: !50 version 0.2.0, !37 Support for recurrent networks
Pipeline #39005 failed
@@ -191,7 +191,7 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
     }
 
     SECTION("LSTM(forward)") {
-        auto myLSTM = LSTM(32, 64, 16, "ltsm");
+        auto myLSTM = LSTM(32, 64, 16, true, "ltsm");
         auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
         auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
 
@@ -221,4 +221,61 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
         auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
         microGraphScheduler->saveSchedulingDiagram("lstm_scheduling");
     }
SECTION("LSTM(forward_values)") {
auto myLSTM = LSTM(2, 3, 1, true, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 3);
REQUIRE(myLSTM->nbOutputs() == 2);
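+        // 19 inputs in total (layout inferred from this test rather than
+        // from the LSTM API docs): input 0 is the sequence X, inputs 1-4
+        // and 5-8 are the per-gate input and recurrent weights, and
+        // inputs 17-18 are presumably the initial hidden and cell states.
+        // X plus the two initial states account for the 3 data inputs.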
+
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}});
+        std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
+        std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
+        std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
+
+        op->associateInput(0, myInput);
+        op->associateInput(17, myInit);
+        op->associateInput(18, myInit);
+
+        op->computeOutputDims();
+        REQUIRE(op->outputDimsForwarded());
+        microGraph->save("lstm_values_dims", false, true);
+
+        // Weights X
+        op->associateInput(1, myInitW);
+        op->associateInput(2, myInitW);
+        op->associateInput(3, myInitW);
+        op->associateInput(4, myInitW);
+        // Weights H
+        op->associateInput(5, myInitR);
+        op->associateInput(6, myInitR);
+        op->associateInput(7, myInitR);
+        op->associateInput(8, myInitR);
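+        // By elimination, inputs 9-16 should be the gate biases; they are
+        // left unset in this test (so presumably treated as zero).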
+
+        op->setDataType(DataType::Float32);
+        op->setBackend("cpu");
+
+        std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.0952412, 0.0952412, 0.0952412},
+                                  {0.25606447, 0.25606447, 0.25606447},
+                                  {0.40323776, 0.40323776, 0.40323776}}});
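+        // Hand derivation of the expected values, assuming the standard
+        // LSTM cell equations (one time step, all weights 0.1, zero biases
+        // and zero initial states). For the first batch item x = {1, 2}:
+        //   pre-activation: 0.1 * (1 + 2) = 0.3
+        //   i = f = o = sigmoid(0.3) ~= 0.574443
+        //   g = tanh(0.3)            ~= 0.291313
+        //   c = f * 0 + i * g        ~= 0.167344
+        //   h = o * tanh(c)          ~= 0.0952412  (first row above)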
+
+        op->forward();
+
+        auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
+        microGraphScheduler->saveSchedulingDiagram("lstm_values_scheduling");
+
+        op->getOutput(1)->print();
+        myHiddenState->print();
+
+        REQUIRE(approxEq<float>(*(op->getOutput(1)), *myHiddenState));
+    }
 }
\ No newline at end of file