Support for recurrent networks
Add support for recurrent networks, with a `Memorize` Operator.
```mermaid
%%{init: {'flowchart': { 'curve': 'monotoneY'}, 'fontFamily': 'Verdana' } }%%
graph TD
    Op((MemorizeOp))
    In0[data_input] --> |"In[0]"| Op
    In1[data_input_init] --> |"In[1]"| Op
    Op --> |"Out[0]"| Out0[data_output]
    Op --> |"Out[1]"| Out1[data_output_rec]
```
See the unit test that implements the following network:
```mermaid
%%{init: {'flowchart': { 'curve': 'monotoneY'}, 'fontFamily': 'Verdana' } }%%
flowchart TB
    Add_0("add1\n<sub><em>(Add#0)</em></sub>"):::rootCls
    Memorize_0("mem1\n<sub><em>(Memorize#0)</em></sub>")
    Add_1("add2\n<sub><em>(Add#1)</em></sub>")
    Producer_2("bias\n<sub><em>(Producer#2)</em></sub>"):::producerCls
    Producer_1("init\n<sub><em>(Producer#1)</em></sub>"):::producerCls
    Producer_0("input\n<sub><em>(Producer#0)</em></sub>"):::producerCls
    Add_0-->|"0→0"|Memorize_0
    Memorize_0-->|"0→0"|Add_1
    Memorize_0-->|"1→1"|Add_0
    Producer_2-->|"0 [2, 3]→1"|Add_1
    Producer_1-->|"0 [2, 3]→1"|Memorize_0
    Producer_0-->|"0 [2, 3]→0"|Add_0
    input0((in#0)):::inputCls--->|→1|Add_0
    Add_1--->|"0→"|output0((out#0)):::outputCls
    Memorize_0--->|"1→"|output1((out#1)):::outputCls
    classDef inputCls fill:#afa
    classDef outputCls fill:#ffa
    classDef externalCls fill:#ccc
    classDef producerCls fill:#ccf
    classDef genericCls fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5
    classDef metaCls stroke-width:5px
    classDef rootCls stroke:#f00
    classDef producerCls_rootCls stroke:#f00,fill:#ccf
    classDef genericCls_rootCls stroke:#f00,fill:#f9f9ff,stroke-width:1px,stroke-dasharray: 5 5
    classDef metaCls_rootCls stroke:#f00,stroke-width:5px
```
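Below is a hedged Python sketch of how this test network could be assembled by hand (the actual unit test is in C++). It assumes `add_child(other, out_id, in_id)` semantics and the factory signatures as written, both of which are assumptions:

```python
import numpy as np
import aidge_core

def tensor():
    return aidge_core.Tensor(np.ones((2, 3), dtype=np.float32))

inp  = aidge_core.Producer(tensor(), "input")
init = aidge_core.Producer(tensor(), "init")
bias = aidge_core.Producer(tensor(), "bias")
add1 = aidge_core.Add("add1")          # signature assumed; it has varied across versions
mem1 = aidge_core.Memorize(3, "mem1")  # EndStep = 3
add2 = aidge_core.Add("add2")

inp.add_child(add1, 0, 0)    # input       -> add1 In[0]
mem1.add_child(add1, 1, 1)   # mem1 Out[1] -> add1 In[1] (recurrent edge)
add1.add_child(mem1, 0, 0)   # add1        -> mem1 In[0]
init.add_child(mem1, 0, 1)   # init        -> mem1 In[1]
mem1.add_child(add2, 0, 0)   # mem1 Out[0] -> add2 In[0]
bias.add_child(add2, 0, 1)   # bias        -> add2 In[1]

# Build the GraphView from any connected node, using the
# getConnectedGraphView() helper introduced in this MR (Python name assumed).
graph = aidge_core.get_connected_graph_view(add1)
```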
It will produce the following scheduling, with `Memorize`'s attribute `EndStep` set to 3:
```mermaid
gantt
dateFormat x
axisFormat %Q ms
Producer_140736979873424 :0, 2
Memorize_140736979872064 :8, 16
Producer_140736979872544 :21, 22
Add_140736979871824 :26, 60
Add_140736979872784 :63, 78
Memorize_140736979872064 :83, 85
Producer_140736979872544 :88, 90
Add_140736979871824 :93, 107
Add_140736979872784 :110, 124
Memorize_140736979872064 :129, 132
Producer_140736979872544 :135, 136
Add_140736979871824 :139, 153
Add_140736979872784 :157, 171
Memorize_140736979872064 :174, 176
Producer_140736979872544 :179, 180
Add_140736979872784 :183, 197
```
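A diagram like this can be obtained by running the scheduler and dumping its timing. Here is a sketch, under the assumption that the Python bindings expose `SequentialScheduler` and a `save_scheduling_diagram()` method (names inferred from the C++ API):

```python
import aidge_core
import aidge_backend_cpu  # provides the "cpu" implementations

graph.set_backend("cpu")  # `graph` from the sketch above
graph.forward_dims()      # propagate tensor dimensions through the GraphView
scheduler = aidge_core.SequentialScheduler(graph)
scheduler.forward()
scheduler.save_scheduling_diagram("scheduling")  # writes a Mermaid Gantt chart
```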
Other changes:

- `Memorize` Operator;
- `Producer` operator production consumption model slightly changed;
- `forwardDims()`;
- `Tanh` and `Sigmoid` operators, necessary for vanilla LSTM;
- `Pop` operator that generates a sequence from a `Producer`;
- `getConnectedGraphView()` to build a `GraphView` from connected nodes;
- `ProducerOp` and `GenericOp` have a default implementation now, allowing them to be used for scheduling;
- `GraphView` and `Scheduler`: `GraphView::getRankedNodes()` and `GraphView::getRankedNodesName()` methods, using the format syntax `std::format` is based on;
- Ported `MemoryManager` from N2D2 to Aidge;
- `SequentialScheduler::generateMemory()`, which does the same as the old `N2D2::CPP_DeepNetExport::generateMemory()` but in a much more general way (only the specific handling of the `Concat` operator, which added a lot of complexity in N2D2, is missing);
- `forwardDims()`.
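As an example of the new ranking helpers, the node labels used in the flowchart above (e.g. `add1 (Add#0)`) can be generated from a format string. A hedged sketch, assuming a Python binding named `get_ranked_nodes_name` and fmt-style positional fields:

```python
# Assumed field meanings: {0} node name, {1} node type, {3} rank among
# nodes of the same type, yielding e.g. "add1 (Add#0)".
labels = graph.get_ranked_nodes_name("{0} ({1}#{3})")
```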
I feel this MR may become too big and take too much time if it also integrates the following things:

- Test and run LSTM ONNX import with actual values. We don't have a simple model to test, and the model from PulseAudition still has some missing operators... but it is correctly imported for now, give it a try:

  ```python
  import aidge_core
  import aidge_backend_cpu
  import aidge_onnx

  # Import the ONNX model, then save the resulting graph for inspection.
  model = aidge_onnx.load_onnx("PulseAudition_LSTM.onnx", True)
  model.save("test", True, False)
  ```
- Implement an ONNX import able to translate `Loop` using the proposed `Memorize` mechanism;
- The `Memorize` Operator does not store intermediate outputs right now (this will be needed for learning).