diff --git a/aidge_interop_torch/utils.py b/aidge_interop_torch/utils.py index bdeabf80c1a55676f202c7d2aa80ba46fbd7157c..930e1899baea55561d981a36fca9fa0a20849e92 100644 --- a/aidge_interop_torch/utils.py +++ b/aidge_interop_torch/utils.py @@ -10,7 +10,8 @@ import aidge_learning from onnxsim import simplify from typing import Union - +from pathlib import Path +import warnings def convert_tensor(tensor): """Convert a torch tensor to :py:class:`aidge_core.Tensor` and vice versa. @@ -259,16 +260,18 @@ class ContextNoBatchNormFuse: """ cpt = 0 for module in self.model.modules(): - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): - # Restore Batchnorm forward - # torch.nn.modules.batchnorm._BatchNorm.forward - module.forward = self.forwards[cpt] - cpt += 1 + pass + # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + # # Restore Batchnorm forward + # # torch.nn.modules.batchnorm._BatchNorm.forward + # module.forward = self.forwards[cpt] + # cpt += 1 def wrap(torch_model: torch.nn.Module, input_size: Union[list, tuple], - opset_version: int = 11, + opset_version: int = 18, + save_onnx_model:bool = False, in_names: list = None, out_names: list = None, verbose: bool = False) -> AidgeModule: @@ -281,6 +284,8 @@ def wrap(torch_model: torch.nn.Module, :type input_size: ``list`` :param opset_version: Opset version used to generate the intermediate ONNX file, default=11 :type opset_version: int, optional + :param save_onnx_model: If True intermediate onnx files are saved, default=False + :type save_onnx_model: bool, optional :param in_names: Specify specific names for the network inputs :type in_names: list, optional :param out_names: Specify specific names for the network outputs @@ -290,7 +295,7 @@ def wrap(torch_model: torch.nn.Module, :return: A custom ``torch.nn.Module`` which embed a :py:class:`aidge_core.GraphView`. :rtype: :py:class:`aidge_interop_torch.AidgeModule` """ - raw_model_path = f'./{torch_model.__class__.__name__}_raw.onnx' + raw_model_path = Path(f'./{torch_model.__class__.__name__}_raw.onnx') model_path = f'./{torch_model.__class__.__name__}.onnx' print("Exporting torch module to ONNX ...") @@ -302,30 +307,41 @@ def wrap(torch_model: torch.nn.Module, dummy_in = torch.zeros(input_size).to(torch_device) - # Setting model to training + # Setting model to eval # important to keep information with BatchNorm torch_model.train() + # removing spam warning from pytorch + warnings.filterwarnings("ignore", message=".*Constant folding - Only steps=1 can be constant folded.*") + warnings.filterwarnings( + "ignore", + message="ONNX export mode is set to TrainingMode.EVAL, but operator 'batch_norm' is set to train=True. Exporting with train=True." + ) + + # Note : To keep batchnorm we export model in train mode. # However we cannot freeze batchnorm stats in pytorch < 12 (see : https://github.com/pytorch/pytorch/issues/75252). # And even in > 12 when stats freezed the ONNX graph drastically changes ... # To deal with this issue we use a context which change the forward behavior of batchnorm to protect stats. - with ContextNoBatchNormFuse(torch_model) as ctx: - torch.onnx.export(torch_model, - dummy_in, - raw_model_path, - verbose=verbose, - input_names=in_names, - output_names=out_names, - export_params=True, - opset_version=opset_version, - do_constant_folding=False) + # with ContextNoBatchNormFuse(torch_model) as ctx: + torch.onnx.export(torch_model, + dummy_in, + raw_model_path, + verbose=verbose, + input_names=in_names, + output_names=out_names, + export_params=True, + opset_version=opset_version, + do_constant_folding=False + ) print("Simplifying the ONNX model ...") onnx_model = onnx.load(raw_model_path) + raw_model_path.unlink() model_simp, check = simplify(onnx_model) assert check, "Simplified ONNX model could not be validated" - onnx.save(model_simp, model_path) + if save_onnx_model: + onnx.save(model_simp, model_path) aidge_model = aidge_onnx.onnx_import.convert_onnx_to_aidge(model_simp) aidge_core.remove_flatten(aidge_model)