diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 271a1e2f9860d92f840916f6b2e396993b0bea39..3286cbb122a8b68ca2e51f4685060f408826b865 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -9,6 +9,7 @@ * ********************************************************************************/ +#include <aidge/operator/Operator.hpp> #include <catch2/catch_test_macros.hpp> #include <cmath> #include <cstdlib> @@ -192,21 +193,50 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { PaddedConv(3, 4, {3, 3}, "myPaddedConv", {1, 1}, {1, 1, 1, 1}); } SECTION("LSTM(forward)") { + + // The `Pop` operator will, at each execution, pop 32 elements. auto pop = Pop(); - auto myLSTM = LSTM(32, 64, 0, true, "ltsm"); + + auto myLSTM = LSTM(/* in_channels */ 32, + /* hidden_channels */ 64, + /* seq_length */ 0, + /* no_bias*/ true, + /* name */ "ltsm"); + auto op = std::dynamic_pointer_cast<MetaOperator_Op>(myLSTM->getOperator()); auto microGraph = op->getMicroGraph(); microGraph->save("lstm", false, true); - REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); + // LSTM has a total of 19 inputs : + // 8 weights + auto constexpr nbWeightInputs = 8; + // 8 bias + auto constexpr nbBiasInputs = 8; + // 2 init inputs for the cell state and hidden state, + // that corresponds to the init inputs of the memory nodes + // in our implementation + auto constexpr nbMemorizeInputs = 2; + // 1 'real' input + + REQUIRE(myLSTM->nbInputs() == nbWeightInputs + nbBiasInputs + nbMemorizeInputs + 1); REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data); for (size_t i = 1; i < 9; ++i) { + // Inputs 1 to 8 are weights, which are mandatory parameters REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param); } for (size_t i = 9; i < 17; ++i) { + // Inputs 9 to 16 are bias, which are optional parameters, + // controlled by the `no_bias` option in the LSTM() constructor REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam); } + + // The initial input of a Memorize_Op is an InputCategory::Param + REQUIRE(myLSTM->inputCategory(17) == InputCategory::Param); + REQUIRE(myLSTM->inputCategory(18) == InputCategory::Param); + + // LSTM has two outpus: + // The hidden state and the cell state REQUIRE(myLSTM->nbOutputs() == 2); std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( @@ -218,6 +248,7 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>( Array2D<float, 64, 64>{}); + // Connect everything pop->addChild(myLSTM, 0, 0); pop->getOperator()->associateInput(0, myInput); op->associateInput(17, myInit); @@ -249,8 +280,27 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); microGraphScheduler->saveSchedulingDiagram("lstm_scheduling"); - REQUIRE(op->getNbConsumedData(0).data == 512); - REQUIRE(op->getNbConsumedData(1).data == 32768); + + // getNbComsumedData() returns the total amount of data consumed since + // the start of the producer consumer model run. + // It is therefore logical that, at the end of the run, all the input + // has been consumed (32*16 = 512). The subtlety is that it is consumed + // in several stages. + // + // The Pop operator will unpack 32 elements each time it runs. As the + // model is executed, getNbComsumedData() will therefore go from 0 to + // 32, to 64, and so on. This is true as long as the consumer, which is + // connected to Pop's output, can consume something. This is the case + // if, for each of its inputs, there is enough data to consume. + + + // Data consumed by the input #0 is all of the input tensor, = 16 * 32 = 512 + REQUIRE(op->getNbConsumedData(/* inputIdx */ 0).data == 512); + + // Data consumed by the input #1 is (32 * 64) * 16: + // the data of the `myInit` times the number of iterations of Pop. + REQUIRE(op->getNbConsumedData(/* inputIdx */ 1).data == 32768); + REQUIRE(op->getNbProducedData(0).data == 34816); REQUIRE(op->getNbProducedData(1).data == 34816); REQUIRE(microGraphScheduler->getStaticScheduling(0).size() == 26); @@ -519,4 +569,4 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { REQUIRE(approxEq<float>(*(op->getOutput(0)), *myHiddenState)); } -} \ No newline at end of file +}