Skip to content

Support for recurrent networks

Olivier BICHLER requested to merge memorize into dev

Add support for recurrent networks, with a Memorize Operator.

%%{init: {'flowchart': { 'curve': 'monotoneY'}, 'fontFamily': 'Verdana' } }%%
graph TD
    Op((MemorizeOp))
    In0[data_input] --> |"In[0]"| Op
    In1[data_input_init] --> |"In[1]"| Op
    Op --> |"Out[0]"| Out0[data_output]
    Op --> |"Out[1]"| Out1[data_output_rec]

See the unit test that implements the following network:

%%{init: {'flowchart': { 'curve': 'monotoneY'}, 'fontFamily': 'Verdana' } }%%
flowchart TB

Add_0("add1\n<sub><em>(Add#0)</em></sub>"):::rootCls
Memorize_0("mem1\n<sub><em>(Memorize#0)</em></sub>")
Add_1("add2\n<sub><em>(Add#1)</em></sub>")
Producer_2("bias\n<sub><em>(Producer#2)</em></sub>"):::producerCls
Producer_1("init\n<sub><em>(Producer#1)</em></sub>"):::producerCls
Producer_0("input\n<sub><em>(Producer#0)</em></sub>"):::producerCls
Add_0-->|"0&rarr;0"|Memorize_0
Memorize_0-->|"0&rarr;0"|Add_1
Memorize_0-->|"1&rarr;1"|Add_0
Producer_2-->|"0 [2, 3]&rarr;1"|Add_1
Producer_1-->|"0 [2, 3]&rarr;1"|Memorize_0
Producer_0-->|"0 [2, 3]&rarr;0"|Add_0
input0((in#0)):::inputCls--->|&rarr;1|Add_0
Add_1--->|"0&rarr;"|output0((out#0)):::outputCls
Memorize_0--->|"1&rarr;"|output1((out#1)):::outputCls
classDef inputCls fill:#afa
classDef outputCls fill:#ffa
classDef externalCls fill:#ccc
classDef producerCls fill:#ccf
classDef genericCls fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5
classDef metaCls stroke-width:5px
classDef rootCls stroke:#f00
classDef producerCls_rootCls stroke:#f00,fill:#ccf
classDef genericCls_rootCls stroke:#f00,fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5
classDef metaCls_rootCls stroke:#f00,stroke-width:5px

It will produce the following scheduling, with Memorize's attribute EndStep set to 3:

gantt
dateFormat x
axisFormat %Q ms

Producer_140736979873424 :0, 2
Memorize_140736979872064 :8, 16
Producer_140736979872544 :21, 22
Add_140736979871824 :26, 60
Add_140736979872784 :63, 78
Memorize_140736979872064 :83, 85
Producer_140736979872544 :88, 90
Add_140736979871824 :93, 107
Add_140736979872784 :110, 124
Memorize_140736979872064 :129, 132
Producer_140736979872544 :135, 136
Add_140736979871824 :139, 153
Add_140736979872784 :157, 171
Memorize_140736979872064 :174, 176
Producer_140736979872544 :179, 180
Add_140736979872784 :183, 197

What is in this MR?

  • Novel Memorize Operator;
  • Updated Scheduler to correctly handle this case;
  • Producer operator production consumption model slighly changed;
  • Update of forwardDims();
  • Recursive MetaOperatorMetaOperatorMetaOperatorMetaOperatorMetaOperator
  • Added Tanh and Sigmoid operators necessary for vanilla LSTM;
  • Added a new Pop operator that generate a sequence from a Producer;
  • Added getConnectedGraphView() to build a GraphView from connected nodes;
  • ProducerOp and GenericOp have a default implementation now, allowing them to be used for scheduling;
  • Improved various error messages useful for debugging (there is still work to do!);
  • Much improved display in GraphView and Scheduler:
    • Consistent nodes naming everywhere;
    • New GraphView::getRankedNodes() and GraphView::getRankedNodesName() methods;
    • Use the light-weight C++ {fmt} formatting library, from which the C++20 std::format is based on.
  • Invalid registrar now displays the value of the missing key!
  • Handle multi-output ONNX nodes.

Bonus

  • Ported MemoryManager from N2D2 to Aidge;
  • Added SequentialScheduler::generateMemory() that does the same as old N2D2::CPP_DeepNetExport::generateMemory(), but in a much more general way (only the specific handling of Concat operator, that added a lot of complexity in N2D2, is missing).

TODO

  • Improve forwardDims();
  • Provide a more comprehensive recurrent example (RNN or LSTM);
    • Test with actual values.
  • Import LSTM model with ONNX.

TODO later

I feel this MR may become too big and take too much time to integrate the following things:

  • Test and run LSTM ONNX import with actual values.

    • We don't have a simple model to test. Model from PulseAudition has still some operators missing...

      ...but it is correctly imported for now, give it a try: 😃

      import aidge_core
      import aidge_backend_cpu
      import aidge_onnx
      model = aidge_onnx.load_onnx("PulseAudition_LSTM.onnx", True)
      model.save("test", True, False)
  • Implement an ONNX import able to translate Loop using the proposed Memorize mechanism;

    • Normally, LSTM and GRU should be exported with the corresponding ONNX operators. That was not always the case in the past, but a LSTM or GRU exported as a Loop in ONNX is considered as an error in PyTorch and Tensorflow. We therefore have a lack of actual use case for now.
  • The Memorize Operator does not store intermediate outputs right now (this will be needed for learning).

    • To be seeing when basic learning mechanisms will be in place.
Edited by Olivier BICHLER

Merge request reports

Loading