Skip to content
Snippets Groups Projects
Commit 39d2c8d6 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix bug due to check of value before testing if None. Headers is now a set...

Fix bug due to check of value before testing if None. Headers is now a set avoiding multiple include of the same header. Trying to export an unsupported operator now raise an error.
parent d909b06c
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
...@@ -58,7 +58,8 @@ def export(export_folder_name, graphview, scheduler): ...@@ -58,7 +58,8 @@ def export(export_folder_name, graphview, scheduler):
# For forward file # For forward file
list_actions = op.forward(list_actions) list_actions = op.forward(list_actions)
else:
raise RuntimeError(f"Operator not supported: {node.type()} !")
# Memory management # Memory management
mem_size, mem_info = compute_default_mem_info(scheduler) mem_size, mem_info = compute_default_mem_info(scheduler)
...@@ -76,17 +77,15 @@ def export(export_folder_name, graphview, scheduler): ...@@ -76,17 +77,15 @@ def export(export_folder_name, graphview, scheduler):
# Get entry nodes # Get entry nodes
# Store the datatype & name # Store the datatype & name
list_inputs_name = [] list_inputs_name = []
print(graphview.get_input_nodes())
for node in graphview.get_input_nodes(): for node in graphview.get_input_nodes():
for node_input, outidx in node.inputs(): for idx, node_input_tuple in enumerate(node.inputs()):
node_input, _ = node_input_tuple
if node_input not in graphview.get_nodes(): if node_input is None:
# Case where export_type = aidge2c(node.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, f"{node.name()}_{idx}"))
elif node_input not in graphview.get_nodes():
export_type = aidge2c(node_input.get_operator().get_output(0).dtype()) export_type = aidge2c(node_input.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, node_input.name())) list_inputs_name.append((export_type, node_input.name()))
elif node_input is None:
export_type = aidge2c(node.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, f"{node.name()}_{outidx}"))
# Get output nodes # Get output nodes
...@@ -101,7 +100,7 @@ def export(export_folder_name, graphview, scheduler): ...@@ -101,7 +100,7 @@ def export(export_folder_name, graphview, scheduler):
generate_file( generate_file(
str(dnn_folder / "src" / "forward.cpp"), str(dnn_folder / "src" / "forward.cpp"),
str(ROOT / "templates" / "network" / "network_forward.jinja"), str(ROOT / "templates" / "network" / "network_forward.jinja"),
headers=list_configs, headers=set(list_configs),
actions=list_actions, actions=list_actions,
inputs= list_inputs_name, inputs= list_inputs_name,
outputs=list_outputs_name outputs=list_outputs_name
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment