Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
aidge_interop_torch
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Eclipse Projects
aidge
aidge_interop_torch
Commits
d0cf58bc
Commit
d0cf58bc
authored
1 month ago
by
Cyril Moineau
Browse files
Options
Downloads
Patches
Plain Diff
Fix multiple warnings with torch.
parent
56bccf0a
No related branches found
No related tags found
1 merge request
!4
Fix module
Pipeline
#73968
failed
1 month ago
Stage: static_analysis
Stage: build
Stage: test
Stage: coverage
Changes
1
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
aidge_interop_torch/utils.py
+36
-20
36 additions, 20 deletions
aidge_interop_torch/utils.py
with
36 additions
and
20 deletions
aidge_interop_torch/utils.py
+
36
−
20
View file @
d0cf58bc
...
...
@@ -10,7 +10,8 @@ import aidge_learning
from
onnxsim
import
simplify
from
typing
import
Union
from
pathlib
import
Path
import
warnings
def
convert_tensor
(
tensor
):
"""
Convert a torch tensor to :py:class:`aidge_core.Tensor` and vice versa.
...
...
@@ -259,16 +260,18 @@ class ContextNoBatchNormFuse:
"""
cpt
=
0
for
module
in
self
.
model
.
modules
():
if
isinstance
(
module
,
torch
.
nn
.
modules
.
batchnorm
.
_BatchNorm
):
# Restore Batchnorm forward
# torch.nn.modules.batchnorm._BatchNorm.forward
module
.
forward
=
self
.
forwards
[
cpt
]
cpt
+=
1
pass
# if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
# # Restore Batchnorm forward
# # torch.nn.modules.batchnorm._BatchNorm.forward
# module.forward = self.forwards[cpt]
# cpt += 1
def
wrap
(
torch_model
:
torch
.
nn
.
Module
,
input_size
:
Union
[
list
,
tuple
],
opset_version
:
int
=
11
,
opset_version
:
int
=
18
,
save_onnx_model
:
bool
=
False
,
in_names
:
list
=
None
,
out_names
:
list
=
None
,
verbose
:
bool
=
False
)
->
AidgeModule
:
...
...
@@ -281,6 +284,8 @@ def wrap(torch_model: torch.nn.Module,
:type input_size: ``list``
:param opset_version: Opset version used to generate the intermediate ONNX file, default=11
:type opset_version: int, optional
:param save_onnx_model: If True intermediate onnx files are saved, default=False
:type save_onnx_model: bool, optional
:param in_names: Specify specific names for the network inputs
:type in_names: list, optional
:param out_names: Specify specific names for the network outputs
...
...
@@ -290,7 +295,7 @@ def wrap(torch_model: torch.nn.Module,
:return: A custom ``torch.nn.Module`` which embed a :py:class:`aidge_core.GraphView`.
:rtype: :py:class:`aidge_interop_torch.AidgeModule`
"""
raw_model_path
=
f
'
./
{
torch_model
.
__class__
.
__name__
}
_raw.onnx
'
raw_model_path
=
Path
(
f
'
./
{
torch_model
.
__class__
.
__name__
}
_raw.onnx
'
)
model_path
=
f
'
./
{
torch_model
.
__class__
.
__name__
}
.onnx
'
print
(
"
Exporting torch module to ONNX ...
"
)
...
...
@@ -302,30 +307,41 @@ def wrap(torch_model: torch.nn.Module,
dummy_in
=
torch
.
zeros
(
input_size
).
to
(
torch_device
)
# Setting model to
training
# Setting model to
eval
# important to keep information with BatchNorm
torch_model
.
train
()
# removing spam warning from pytorch
warnings
.
filterwarnings
(
"
ignore
"
,
message
=
"
.*Constant folding - Only steps=1 can be constant folded.*
"
)
warnings
.
filterwarnings
(
"
ignore
"
,
message
=
"
ONNX export mode is set to TrainingMode.EVAL, but operator
'
batch_norm
'
is set to train=True. Exporting with train=True.
"
)
# Note : To keep batchnorm we export model in train mode.
# However we cannot freeze batchnorm stats in pytorch < 12 (see : https://github.com/pytorch/pytorch/issues/75252).
# And even in > 12 when stats freezed the ONNX graph drastically changes ...
# To deal with this issue we use a context which change the forward behavior of batchnorm to protect stats.
with
ContextNoBatchNormFuse
(
torch_model
)
as
ctx
:
torch
.
onnx
.
export
(
torch_model
,
dummy_in
,
raw_model_path
,
verbose
=
verbose
,
input_names
=
in_names
,
output_names
=
out_names
,
export_params
=
True
,
opset_version
=
opset_version
,
do_constant_folding
=
False
)
# with ContextNoBatchNormFuse(torch_model) as ctx:
torch
.
onnx
.
export
(
torch_model
,
dummy_in
,
raw_model_path
,
verbose
=
verbose
,
input_names
=
in_names
,
output_names
=
out_names
,
export_params
=
True
,
opset_version
=
opset_version
,
do_constant_folding
=
False
)
print
(
"
Simplifying the ONNX model ...
"
)
onnx_model
=
onnx
.
load
(
raw_model_path
)
raw_model_path
.
unlink
()
model_simp
,
check
=
simplify
(
onnx_model
)
assert
check
,
"
Simplified ONNX model could not be validated
"
onnx
.
save
(
model_simp
,
model_path
)
if
save_onnx_model
:
onnx
.
save
(
model_simp
,
model_path
)
aidge_model
=
aidge_onnx
.
onnx_import
.
convert_onnx_to_aidge
(
model_simp
)
aidge_core
.
remove_flatten
(
aidge_model
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment