Skip to content
Snippets Groups Projects
Commit 7776fc6a authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Remove from input_nodes nodes which only have unconnected data that are Optionnal.

parent 1f18fa20
No related branches found
No related tags found
No related merge requests found
...@@ -50,9 +50,19 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = ...@@ -50,9 +50,19 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
if not isinstance(op_impl, ExportLib): if not isinstance(op_impl, ExportLib):
raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an exportable backend ({op_impl}).") raise RuntimeError(f"Operator {node.name()}[{node.type()}] doesn't have an exportable backend ({op_impl}).")
is_input = node in graphview.get_input_nodes() is_input:bool = node in graphview.get_input_nodes()
is_output = node in graphview.get_output_nodes() is_output:bool = node in graphview.get_output_nodes()
if is_input:
# GraphView.get_inputs_nodes() returns the nodes that have an Input set to None or not in the graph
# However, some inputs are Optional and thus the node may not be an input of the graph!
# So we need ot check that all the inputs of the nodes or in the graph or not optional
# This is what the following code block is checking.
for idx, node_in in enumerate(node.inputs()):
optional:bool = node.get_operator().input_category(idx) == aidge_core.InputCategory.OptionalData
# Note: node_in is a Tuple(Node, out_idx)
in_graph:bool = node_in[0] in graphview.get_nodes()
is_input &= (in_graph or not optional)
required_specs = op_impl.get_required_spec() required_specs = op_impl.get_required_spec()
specs = op_impl.get_best_match(required_specs) specs = op_impl.get_best_match(required_specs)
...@@ -76,7 +86,8 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = ...@@ -76,7 +86,8 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
for idx in range(len(node.outputs())): for idx in range(len(node.outputs())):
outputs_name.append(op.attributes["out_name"][idx]) outputs_name.append(op.attributes["out_name"][idx])
outputs_dtype.append( outputs_dtype.append(
op.attributes["out_cdtype"][idx]) op.attributes["out_cdtype"][idx]
)
outputs_size.append(op.attributes["out_size"][idx]) outputs_size.append(op.attributes["out_size"][idx])
func_name = "model_forward" func_name = "model_forward"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment