From 39d2c8d6a61fb9f76f5eea26ed975f1711c14e9d Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Mon, 1 Jul 2024 11:44:23 +0000
Subject: [PATCH] Fix bug due to check of value before testing if None. Headers
 is now a set avoiding multiple include of the same header. Trying to export
 an unsupported operator now raise an error.

---
 aidge_export_cpp/export.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/aidge_export_cpp/export.py b/aidge_export_cpp/export.py
index 2e836f9..a57264a 100644
--- a/aidge_export_cpp/export.py
+++ b/aidge_export_cpp/export.py
@@ -58,7 +58,8 @@ def export(export_folder_name, graphview, scheduler):
 
             # For forward file
             list_actions = op.forward(list_actions)
-
+        else:
+            raise RuntimeError(f"Operator not supported: {node.type()} !")
 
     # Memory management
     mem_size, mem_info = compute_default_mem_info(scheduler)
@@ -76,17 +77,15 @@ def export(export_folder_name, graphview, scheduler):
     # Get entry nodes
     # Store the datatype & name
     list_inputs_name = []
-    print(graphview.get_input_nodes())
     for node in graphview.get_input_nodes():
-        for node_input, outidx in node.inputs():
-
-            if node_input not in graphview.get_nodes():
-                # Case where
+        for idx, node_input_tuple in enumerate(node.inputs()):
+            node_input, _ = node_input_tuple
+            if node_input is None:
+                export_type = aidge2c(node.get_operator().get_output(0).dtype())
+                list_inputs_name.append((export_type, f"{node.name()}_{idx}"))
+            elif node_input not in graphview.get_nodes():
                 export_type = aidge2c(node_input.get_operator().get_output(0).dtype())
                 list_inputs_name.append((export_type, node_input.name()))
-            elif node_input is None:
-                export_type = aidge2c(node.get_operator().get_output(0).dtype())
-                list_inputs_name.append((export_type, f"{node.name()}_{outidx}"))
 
 
     # Get output nodes
@@ -101,7 +100,7 @@ def export(export_folder_name, graphview, scheduler):
     generate_file(
         str(dnn_folder / "src" / "forward.cpp"),
         str(ROOT / "templates" / "network" / "network_forward.jinja"),
-        headers=list_configs,
+        headers=set(list_configs),
         actions=list_actions,
         inputs= list_inputs_name,
         outputs=list_outputs_name
-- 
GitLab