Skip to content
Snippets Groups Projects
Commit 387f423b authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added LSTM example (not tested yet with actual values)

parent 0c96654a
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!37Support for recurrent networks
Pipeline #38760 failed
......@@ -24,7 +24,8 @@
using namespace Aidge;
TEST_CASE("[cpu/operator] MetaOperator/PaddedConv(forward)", "[MetaOperator][PaddedConv][CPU]") {
TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
SECTION("PaddedConv(forward)") {
std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(
Array4D<double, 4, 3, 3, 3>{{{{{6.20986394e-01, 1.19775136e-03, 7.22876095e-02},
{1.16492919e-01, 8.21634093e-02, 1.17413265e-01},
......@@ -187,4 +188,37 @@ TEST_CASE("[cpu/operator] MetaOperator/PaddedConv(forward)", "[MetaOperator][Pad
std::shared_ptr<Node> myPaddedConv =
PaddedConv(3, 4, {3, 3}, "myPaddedConv", {1, 1}, {1, 1, 1, 1});
}
SECTION("LTSM(forward)") {
auto myLSTM = LTSM(32, 64, 16, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3);
REQUIRE(myLSTM->nbData() == 3);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
Array1D<float, 32>{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}});
std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
Array2D<float, 1, 64>{{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
op->associateInput(0, myInput);
op->associateInput(1, myInit);
op->associateInput(2, myInit);
op->computeOutputDims();
REQUIRE(op->outputDimsForwarded());
microGraph->save("lstm_dims", false, false);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
op->forward();
auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
microGraphScheduler->saveSchedulingDiagram("lstm_scheduling");
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment