From 7776fc6a421d8d07780a822b17bb9153e38c843e Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 26 Sep 2024 08:11:41 +0000 Subject: [PATCH] Remove from input_nodes nodes which only have unconnected data that are Optionnal. --- aidge_core/export_utils/scheduler_export.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index 5b50f803f..33869ea5a 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -50,9 +50,19 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = if not isinstance(op_impl, ExportLib): raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an exportable backend ({op_impl}).") - is_input = node in graphview.get_input_nodes() - is_output = node in graphview.get_output_nodes() + is_input:bool = node in graphview.get_input_nodes() + is_output:bool = node in graphview.get_output_nodes() + if is_input: + # GraphView.get_inputs_nodes() returns the nodes that have an Input set to None or not in the graph + # However, some inputs are Optional and thus the node may not be an input of the graph! + # So we need ot check that all the inputs of the nodes or in the graph or not optional + # This is what the following code block is checking. + for idx, node_in in enumerate(node.inputs()): + optional:bool = node.get_operator().input_category(idx) == aidge_core.InputCategory.OptionalData + # Note: node_in is a Tuple(Node, out_idx) + in_graph:bool = node_in[0] in graphview.get_nodes() + is_input &= (in_graph or not optional) required_specs = op_impl.get_required_spec() specs = op_impl.get_best_match(required_specs) @@ -76,7 +86,8 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = for idx in range(len(node.outputs())): outputs_name.append(op.attributes["out_name"][idx]) outputs_dtype.append( - op.attributes["out_cdtype"][idx]) + op.attributes["out_cdtype"][idx] + ) outputs_size.append(op.attributes["out_size"][idx]) func_name = "model_forward" -- GitLab