From e5cc0ce2639b85ec6cad3dfa0a4b67444fc8d812 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 13 Feb 2024 17:47:18 +0100 Subject: [PATCH] First working version of LSTM inference --- unit_tests/operator/Test_MetaOperator.cpp | 59 ++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 2b63e1e1..4f74c5ce 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -191,7 +191,7 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { } SECTION("LSTM(forward)") { - auto myLSTM = LSTM(32, 64, 16, "ltsm"); + auto myLSTM = LSTM(32, 64, 16, true, "ltsm"); auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph(); @@ -221,4 +221,61 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); microGraphScheduler->saveSchedulingDiagram("lstm_scheduling"); } + + SECTION("LSTM(forward_values)") { + auto myLSTM = LSTM(2, 3, 1, true, "ltsm"); + auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); + + auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph(); + microGraph->save("lstm", false, false); + + REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); + REQUIRE(myLSTM->nbData() == 3); + REQUIRE(myLSTM->nbOutputs() == 2); + + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( + Array2D<float, 3, 2>{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}}); + std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>( + Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}}); + std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>( + Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}}); + std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>( + Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}}); + + op->associateInput(0, myInput); + op->associateInput(17, myInit); + op->associateInput(18, myInit); + + op->computeOutputDims(); + REQUIRE(op->outputDimsForwarded()); + microGraph->save("lstm_values_dims", false, true); + + // Weights X + op->associateInput(1, myInitW); + op->associateInput(2, myInitW); + op->associateInput(3, myInitW); + op->associateInput(4, myInitW); + // Weights H + op->associateInput(5, myInitR); + op->associateInput(6, myInitR); + op->associateInput(7, myInitR); + op->associateInput(8, myInitR); + + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + + std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>( + Array2D<float, 3, 3>{{{0.0952412, 0.0952412, 0.0952412}, + {0.25606447, 0.25606447, 0.25606447}, + {0.40323776, 0.40323776, 0.40323776}}}); + + op->forward(); + auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); + microGraphScheduler->saveSchedulingDiagram("lstm_values_scheduling"); + + op->getOutput(1)->print(); + myHiddenState->print(); + + REQUIRE(approxEq<float>(*(op->getOutput(1)), *myHiddenState)); + } } \ No newline at end of file -- GitLab