From fb1f9a843a9fcf89ab79c35e4bb34339430df404 Mon Sep 17 00:00:00 2001 From: Cyril Moineau <cyril.moineau@cea.fr> Date: Fri, 8 Nov 2024 16:08:18 +0000 Subject: [PATCH] Remove main.cpp generation from export_scheduler --- aidge_core/export_utils/__init__.py | 2 +- aidge_core/export_utils/export_registry.py | 4 +- aidge_core/export_utils/generate_main.py | 51 +++++++++++++++++++++ aidge_core/export_utils/scheduler_export.py | 26 +---------- 4 files changed, 54 insertions(+), 29 deletions(-) create mode 100644 aidge_core/export_utils/generate_main.py diff --git a/aidge_core/export_utils/__init__.py b/aidge_core/export_utils/__init__.py index a97e97874..b17ff90d6 100644 --- a/aidge_core/export_utils/__init__.py +++ b/aidge_core/export_utils/__init__.py @@ -3,4 +3,4 @@ from .code_generation import generate_file, generate_str, copy_file from .export_registry import ExportLib from .scheduler_export import scheduler_export from .tensor_export import tensor_to_c, generate_input_file - +from .generate_main import generate_main_cpp diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index fd24008a6..70c3e5fa4 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -28,10 +28,8 @@ class ExportLib(aidge_core.OperatorImpl): # key: Path where static file is # Value: Path where to copy the file relative to the export root static_files: Dict[str, str] = {} - # Custom main generation jinja file - main_jinja_path = None # Main memory section - memory_section = None + mem_section = None # PRIVATE # Registry of exportNode, class level dictionary, shared across all ExportLib _cls_export_node_registry = {} diff --git a/aidge_core/export_utils/generate_main.py b/aidge_core/export_utils/generate_main.py new file mode 100644 index 000000000..b7eee9306 --- /dev/null +++ b/aidge_core/export_utils/generate_main.py @@ -0,0 +1,51 @@ +import aidge_core +from pathlib import Path +from aidge_core.export_utils import generate_file, data_conversion + +def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView) -> None: + """ + Generate a C++ file to manage the forward pass of a model using the given graph structure. + + This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types, + and tensor sizes. It uses this data to populate a C++ file template (`main.jinja`), creating a file (`main.cpp`) + that call the `model_forward` function, which handles data flow and processing for the exported model. + + :param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved. + :type export_folder: str + :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and + ordered input/output data within the computational graph. + :type graph_view: aidge_core.graph_view + :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes), + indicating an internal bug in the graph representation. + """ + outputs_name: list[str] = [] + outputs_dtype: list[str] = [] + outputs_size: list[int] = [] + inputs_name: list[str] = [] + gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs() + gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs() + + for in_node, in_idx in gv_inputs: + in_node_input, in_node_input_idx = in_node.input(in_idx) + inputs_name.append(f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}") + for out_node, out_id in gv_outputs: + outputs_name.append(f"{out_node.name()}_output_{out_id}") + out_tensor = out_node.get_operator().get_output(out_id) + outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype())) + outputs_size.append(out_tensor.size()) + print(out_tensor.size()) + + + if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): + raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.") + + ROOT = Path(__file__).resolve().parents[0] + generate_file( + str(Path(export_folder) / "main.cpp"), + str(ROOT / "templates" / "main.jinja"), + func_name="model_forward", + inputs_name=inputs_name, + outputs_name=outputs_name, + outputs_dtype=outputs_dtype, + outputs_size=outputs_size + ) diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index df0b4a385..f1de6f823 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -98,14 +98,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = outputs_size.append(op.attributes["out_size"][idx]) func_name = "model_forward" - - - args = ", ".join([f"const {dtype}* {name}" for name, - dtype in zip(inputs_name, inputs_dtype)]) - args += ", " +", ".join([f"{dtype}** {name}" for name, - dtype in zip(outputs_name, outputs_dtype)]) - forward_func = f"void {func_name}({args})" - ROOT = Path(__file__).resolve().parents[0] generate_file( str(dnn_folder / "src" / "forward.cpp"), @@ -114,7 +106,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = headers=set(list_configs), actions=list_actions, mem_ctype=inputs_dtype[0], # Legacy behavior ... - mem_section=export_lib.mem_section, + mem_section=export_lib.mem_section, peak_mem=peak_mem, inputs_name=inputs_name, inputs_dtype=inputs_dtype, @@ -137,22 +129,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.") - if export_lib is not None and export_lib.main_jinja_path is not None: - main_jinja_path = export_lib.main_jinja_path - else : - main_jinja_path = str(ROOT / "templates" / "main.jinja") - - generate_file( - str(export_folder / "main.cpp"), - main_jinja_path, - func_name=func_name, - inputs_name=inputs_name, - outputs_name=outputs_name, - outputs_dtype=outputs_dtype, - outputs_size=outputs_size, - labels=labels - ) - if export_lib is not None: # Copy all static files in the export for source, destination in export_lib.static_files.items(): -- GitLab