Skip to content
Snippets Groups Projects
Commit d0cf58bc authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix multiple warnings with torch.

parent 56bccf0a
No related branches found
No related tags found
1 merge request!4Fix module
Pipeline #73968 failed
......@@ -10,7 +10,8 @@ import aidge_learning
from onnxsim import simplify
from typing import Union
from pathlib import Path
import warnings
def convert_tensor(tensor):
"""Convert a torch tensor to :py:class:`aidge_core.Tensor` and vice versa.
......@@ -259,16 +260,18 @@ class ContextNoBatchNormFuse:
"""
cpt = 0
for module in self.model.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
# Restore Batchnorm forward
# torch.nn.modules.batchnorm._BatchNorm.forward
module.forward = self.forwards[cpt]
cpt += 1
pass
# if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
# # Restore Batchnorm forward
# # torch.nn.modules.batchnorm._BatchNorm.forward
# module.forward = self.forwards[cpt]
# cpt += 1
def wrap(torch_model: torch.nn.Module,
input_size: Union[list, tuple],
opset_version: int = 11,
opset_version: int = 18,
save_onnx_model:bool = False,
in_names: list = None,
out_names: list = None,
verbose: bool = False) -> AidgeModule:
......@@ -281,6 +284,8 @@ def wrap(torch_model: torch.nn.Module,
:type input_size: ``list``
:param opset_version: Opset version used to generate the intermediate ONNX file, default=11
:type opset_version: int, optional
:param save_onnx_model: If True intermediate onnx files are saved, default=False
:type save_onnx_model: bool, optional
:param in_names: Specify specific names for the network inputs
:type in_names: list, optional
:param out_names: Specify specific names for the network outputs
......@@ -290,7 +295,7 @@ def wrap(torch_model: torch.nn.Module,
:return: A custom ``torch.nn.Module`` which embed a :py:class:`aidge_core.GraphView`.
:rtype: :py:class:`aidge_interop_torch.AidgeModule`
"""
raw_model_path = f'./{torch_model.__class__.__name__}_raw.onnx'
raw_model_path = Path(f'./{torch_model.__class__.__name__}_raw.onnx')
model_path = f'./{torch_model.__class__.__name__}.onnx'
print("Exporting torch module to ONNX ...")
......@@ -302,30 +307,41 @@ def wrap(torch_model: torch.nn.Module,
dummy_in = torch.zeros(input_size).to(torch_device)
# Setting model to training
# Setting model to eval
# important to keep information with BatchNorm
torch_model.train()
# removing spam warning from pytorch
warnings.filterwarnings("ignore", message=".*Constant folding - Only steps=1 can be constant folded.*")
warnings.filterwarnings(
"ignore",
message="ONNX export mode is set to TrainingMode.EVAL, but operator 'batch_norm' is set to train=True. Exporting with train=True."
)
# Note : To keep batchnorm we export model in train mode.
# However we cannot freeze batchnorm stats in pytorch < 12 (see : https://github.com/pytorch/pytorch/issues/75252).
# And even in > 12 when stats freezed the ONNX graph drastically changes ...
# To deal with this issue we use a context which change the forward behavior of batchnorm to protect stats.
with ContextNoBatchNormFuse(torch_model) as ctx:
torch.onnx.export(torch_model,
dummy_in,
raw_model_path,
verbose=verbose,
input_names=in_names,
output_names=out_names,
export_params=True,
opset_version=opset_version,
do_constant_folding=False)
# with ContextNoBatchNormFuse(torch_model) as ctx:
torch.onnx.export(torch_model,
dummy_in,
raw_model_path,
verbose=verbose,
input_names=in_names,
output_names=out_names,
export_params=True,
opset_version=opset_version,
do_constant_folding=False
)
print("Simplifying the ONNX model ...")
onnx_model = onnx.load(raw_model_path)
raw_model_path.unlink()
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, model_path)
if save_onnx_model:
onnx.save(model_simp, model_path)
aidge_model = aidge_onnx.onnx_import.convert_onnx_to_aidge(model_simp)
aidge_core.remove_flatten(aidge_model)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment