diff --git a/aidge_export_cpp/export.py b/aidge_export_cpp/export.py index 2e836f92d7a3bc8d3cdaa0aa7fb8a7b370d7a8f4..a57264ad21c27e48c2f68078647ef152397fa4fc 100644 --- a/aidge_export_cpp/export.py +++ b/aidge_export_cpp/export.py @@ -58,7 +58,8 @@ def export(export_folder_name, graphview, scheduler): # For forward file list_actions = op.forward(list_actions) - + else: + raise RuntimeError(f"Operator not supported: {node.type()} !") # Memory management mem_size, mem_info = compute_default_mem_info(scheduler) @@ -76,17 +77,15 @@ def export(export_folder_name, graphview, scheduler): # Get entry nodes # Store the datatype & name list_inputs_name = [] - print(graphview.get_input_nodes()) for node in graphview.get_input_nodes(): - for node_input, outidx in node.inputs(): - - if node_input not in graphview.get_nodes(): - # Case where + for idx, node_input_tuple in enumerate(node.inputs()): + node_input, _ = node_input_tuple + if node_input is None: + export_type = aidge2c(node.get_operator().get_output(0).dtype()) + list_inputs_name.append((export_type, f"{node.name()}_{idx}")) + elif node_input not in graphview.get_nodes(): export_type = aidge2c(node_input.get_operator().get_output(0).dtype()) list_inputs_name.append((export_type, node_input.name())) - elif node_input is None: - export_type = aidge2c(node.get_operator().get_output(0).dtype()) - list_inputs_name.append((export_type, f"{node.name()}_{outidx}")) # Get output nodes @@ -101,7 +100,7 @@ def export(export_folder_name, graphview, scheduler): generate_file( str(dnn_folder / "src" / "forward.cpp"), str(ROOT / "templates" / "network" / "network_forward.jinja"), - headers=list_configs, + headers=set(list_configs), actions=list_actions, inputs= list_inputs_name, outputs=list_outputs_name