From 39d2c8d6a61fb9f76f5eea26ed975f1711c14e9d Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Mon, 1 Jul 2024 11:44:23 +0000 Subject: [PATCH] Fix bug due to check of value before testing if None. Headers is now a set avoiding multiple include of the same header. Trying to export an unsupported operator now raise an error. --- aidge_export_cpp/export.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/aidge_export_cpp/export.py b/aidge_export_cpp/export.py index 2e836f9..a57264a 100644 --- a/aidge_export_cpp/export.py +++ b/aidge_export_cpp/export.py @@ -58,7 +58,8 @@ def export(export_folder_name, graphview, scheduler): # For forward file list_actions = op.forward(list_actions) - + else: + raise RuntimeError(f"Operator not supported: {node.type()} !") # Memory management mem_size, mem_info = compute_default_mem_info(scheduler) @@ -76,17 +77,15 @@ def export(export_folder_name, graphview, scheduler): # Get entry nodes # Store the datatype & name list_inputs_name = [] - print(graphview.get_input_nodes()) for node in graphview.get_input_nodes(): - for node_input, outidx in node.inputs(): - - if node_input not in graphview.get_nodes(): - # Case where + for idx, node_input_tuple in enumerate(node.inputs()): + node_input, _ = node_input_tuple + if node_input is None: + export_type = aidge2c(node.get_operator().get_output(0).dtype()) + list_inputs_name.append((export_type, f"{node.name()}_{idx}")) + elif node_input not in graphview.get_nodes(): export_type = aidge2c(node_input.get_operator().get_output(0).dtype()) list_inputs_name.append((export_type, node_input.name())) - elif node_input is None: - export_type = aidge2c(node.get_operator().get_output(0).dtype()) - list_inputs_name.append((export_type, f"{node.name()}_{outidx}")) # Get output nodes @@ -101,7 +100,7 @@ def export(export_folder_name, graphview, scheduler): generate_file( str(dnn_folder / "src" / "forward.cpp"), str(ROOT / "templates" / "network" / "network_forward.jinja"), - headers=list_configs, + headers=set(list_configs), actions=list_actions, inputs= list_inputs_name, outputs=list_outputs_name -- GitLab