diff --git a/aidge_onnx/onnx_import.py b/aidge_onnx/onnx_import.py index 59b4f42ea618d8b8c4d4d21277fe64f859f6d30b..bb97523dea5beb9377ae358e9e4ce9d16d7bc5fc 100644 --- a/aidge_onnx/onnx_import.py +++ b/aidge_onnx/onnx_import.py @@ -92,8 +92,8 @@ def _load_onnx2graphview(model:onnx.ModelProto, verbose:bool = False): :rtype: :py:class:`aidge_core.GraphView` """ opset: int = None - if hasattr(model, 'ir_version'): - opset = model.ir_version + if hasattr(model, 'opset_import'): + domains = {domain.domain : domain.version for domain in model.opset_import} else: raise RuntimeError("Cannot retieve opset version from ONNX model.") Log.info(f"ONNX metadata:" \ @@ -122,6 +122,10 @@ def _load_onnx2graphview(model:onnx.ModelProto, verbose:bool = False): for onnx_node in model.graph.node: node_name = onnx_node.output[0] # Do not use onnx_node.name as it is not a mandatory value node_inputs[node_name] = [None]*len(onnx_node.input) + # There can be multiple opsets in a given model, each ones attached to a given domain + # Each nodes are attached to a given opset via a domain name. + # more on how opset work here : http://onnx.ai/sklearn-onnx/auto_tutorial/plot_cbegin_opset.html + node_opset = domains[onnx_node.domain] # Adding producers to the list of inputs for input_idx, input_node in enumerate(onnx_node.input): @@ -129,7 +133,7 @@ def _load_onnx2graphview(model:onnx.ModelProto, verbose:bool = False): node_inputs[node_name][input_idx] = (model_producers[input_node], 0) try: - model_nodes[node_name] = ONNX_NODE_CONVERTER_[onnx_node.op_type.lower()](onnx_node, node_inputs[node_name], opset) + model_nodes[node_name] = ONNX_NODE_CONVERTER_[onnx_node.op_type.lower()](onnx_node, node_inputs[node_name], node_opset) except Exception as e: Log.warn(f"An error occured when trying to load node {node_name} of type {onnx_node.op_type.lower()}." f"Loading node using a generic operator."