From fb1f9a843a9fcf89ab79c35e4bb34339430df404 Mon Sep 17 00:00:00 2001
From: Cyril Moineau <cyril.moineau@cea.fr>
Date: Fri, 8 Nov 2024 16:08:18 +0000
Subject: [PATCH] Remove main.cpp generation from export_scheduler

---
 aidge_core/export_utils/__init__.py         |  2 +-
 aidge_core/export_utils/export_registry.py  |  4 +-
 aidge_core/export_utils/generate_main.py    | 51 +++++++++++++++++++++
 aidge_core/export_utils/scheduler_export.py | 26 +----------
 4 files changed, 54 insertions(+), 29 deletions(-)
 create mode 100644 aidge_core/export_utils/generate_main.py

diff --git a/aidge_core/export_utils/__init__.py b/aidge_core/export_utils/__init__.py
index a97e97874..b17ff90d6 100644
--- a/aidge_core/export_utils/__init__.py
+++ b/aidge_core/export_utils/__init__.py
@@ -3,4 +3,4 @@ from .code_generation import generate_file, generate_str, copy_file
 from .export_registry import ExportLib
 from .scheduler_export import scheduler_export
 from .tensor_export import tensor_to_c, generate_input_file
-
+from .generate_main import generate_main_cpp
diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py
index fd24008a6..70c3e5fa4 100644
--- a/aidge_core/export_utils/export_registry.py
+++ b/aidge_core/export_utils/export_registry.py
@@ -28,10 +28,8 @@ class ExportLib(aidge_core.OperatorImpl):
     # key: Path where static file is
     # Value: Path where to copy the file relative to the export root
     static_files: Dict[str, str] = {}
-    # Custom main generation jinja file
-    main_jinja_path = None
     # Main memory section
-    memory_section = None
+    mem_section = None
     # PRIVATE
     # Registry of exportNode, class level dictionary, shared across all ExportLib
     _cls_export_node_registry = {}
diff --git a/aidge_core/export_utils/generate_main.py b/aidge_core/export_utils/generate_main.py
new file mode 100644
index 000000000..b7eee9306
--- /dev/null
+++ b/aidge_core/export_utils/generate_main.py
@@ -0,0 +1,51 @@
+import aidge_core
+from pathlib import Path
+from aidge_core.export_utils import generate_file, data_conversion
+
+def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView) -> None:
+    """
+    Generate a C++ file to manage the forward pass of a model using the given graph structure.
+
+    This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types,
+    and tensor sizes. It uses this data to populate a C++ file template (`main.jinja`), creating a file (`main.cpp`)
+    that call the `model_forward` function, which handles data flow and processing for the exported model.
+
+    :param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved.
+    :type export_folder: str
+    :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and
+                       ordered input/output data within the computational graph.
+    :type graph_view: aidge_core.graph_view
+    :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes),
+                          indicating an internal bug in the graph representation.
+    """
+    outputs_name: list[str] = []
+    outputs_dtype: list[str] = []
+    outputs_size: list[int] = []
+    inputs_name: list[str] = []
+    gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs()
+    gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs()
+
+    for in_node, in_idx in gv_inputs:
+        in_node_input, in_node_input_idx = in_node.input(in_idx)
+        inputs_name.append(f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}")
+    for out_node, out_id in gv_outputs:
+        outputs_name.append(f"{out_node.name()}_output_{out_id}")
+        out_tensor = out_node.get_operator().get_output(out_id)
+        outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype()))
+        outputs_size.append(out_tensor.size())
+        print(out_tensor.size())
+
+
+    if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
+            raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.")
+
+    ROOT = Path(__file__).resolve().parents[0]
+    generate_file(
+        str(Path(export_folder) / "main.cpp"),
+        str(ROOT / "templates" / "main.jinja"),
+        func_name="model_forward",
+        inputs_name=inputs_name,
+        outputs_name=outputs_name,
+        outputs_dtype=outputs_dtype,
+        outputs_size=outputs_size
+    )
diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py
index df0b4a385..f1de6f823 100644
--- a/aidge_core/export_utils/scheduler_export.py
+++ b/aidge_core/export_utils/scheduler_export.py
@@ -98,14 +98,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
                     outputs_size.append(op.attributes["out_size"][idx])
 
         func_name = "model_forward"
-
-
-        args = ", ".join([f"const {dtype}* {name}" for name,
-                         dtype in zip(inputs_name, inputs_dtype)])
-        args += ", " +", ".join([f"{dtype}** {name}" for name,
-                          dtype in zip(outputs_name, outputs_dtype)])
-        forward_func = f"void {func_name}({args})"
-
         ROOT = Path(__file__).resolve().parents[0]
         generate_file(
             str(dnn_folder / "src" / "forward.cpp"),
@@ -114,7 +106,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
             headers=set(list_configs),
             actions=list_actions,
             mem_ctype=inputs_dtype[0],  # Legacy behavior ...
-            mem_section=export_lib.mem_section, 
+            mem_section=export_lib.mem_section,
             peak_mem=peak_mem,
             inputs_name=inputs_name,
             inputs_dtype=inputs_dtype,
@@ -137,22 +129,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
         if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
             raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.")
 
-        if export_lib is not None and export_lib.main_jinja_path is not None:
-            main_jinja_path = export_lib.main_jinja_path
-        else :
-            main_jinja_path = str(ROOT / "templates" / "main.jinja")
-
-        generate_file(
-            str(export_folder / "main.cpp"),
-            main_jinja_path,
-            func_name=func_name,
-            inputs_name=inputs_name,
-            outputs_name=outputs_name,
-            outputs_dtype=outputs_dtype,
-            outputs_size=outputs_size,
-            labels=labels
-        )
-
         if export_lib is not None:
             # Copy all static files in the export
             for source, destination in export_lib.static_files.items():
-- 
GitLab