Skip to content
Snippets Groups Projects
Commit bf7c63df authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'RemoveMainFromExport' into 'dev'

Remove main.cpp generation from export_scheduler

See merge request !246
parents ecb77eed fb1f9a84
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!246Remove main.cpp generation from export_scheduler
Pipeline #58690 passed
......@@ -3,4 +3,4 @@ from .code_generation import generate_file, generate_str, copy_file
from .export_registry import ExportLib
from .scheduler_export import scheduler_export
from .tensor_export import tensor_to_c, generate_input_file
from .generate_main import generate_main_cpp
......@@ -28,10 +28,8 @@ class ExportLib(aidge_core.OperatorImpl):
# key: Path where static file is
# Value: Path where to copy the file relative to the export root
static_files: Dict[str, str] = {}
# Custom main generation jinja file
main_jinja_path = None
# Main memory section
memory_section = None
mem_section = None
# PRIVATE
# Registry of exportNode, class level dictionary, shared across all ExportLib
_cls_export_node_registry = {}
......
import aidge_core
from pathlib import Path
from aidge_core.export_utils import generate_file, data_conversion
def generate_main_cpp(export_folder: str, graph_view: aidge_core.GraphView) -> None:
"""
Generate a C++ file to manage the forward pass of a model using the given graph structure.
This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types,
and tensor sizes. It uses this data to populate a C++ file template (`main.jinja`), creating a file (`main.cpp`)
that call the `model_forward` function, which handles data flow and processing for the exported model.
:param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved.
:type export_folder: str
:param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and
ordered input/output data within the computational graph.
:type graph_view: aidge_core.graph_view
:raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes),
indicating an internal bug in the graph representation.
"""
outputs_name: list[str] = []
outputs_dtype: list[str] = []
outputs_size: list[int] = []
inputs_name: list[str] = []
gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs()
gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs()
for in_node, in_idx in gv_inputs:
in_node_input, in_node_input_idx = in_node.input(in_idx)
inputs_name.append(f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}")
for out_node, out_id in gv_outputs:
outputs_name.append(f"{out_node.name()}_output_{out_id}")
out_tensor = out_node.get_operator().get_output(out_id)
outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype()))
outputs_size.append(out_tensor.size())
print(out_tensor.size())
if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.")
ROOT = Path(__file__).resolve().parents[0]
generate_file(
str(Path(export_folder) / "main.cpp"),
str(ROOT / "templates" / "main.jinja"),
func_name="model_forward",
inputs_name=inputs_name,
outputs_name=outputs_name,
outputs_dtype=outputs_dtype,
outputs_size=outputs_size
)
......@@ -98,14 +98,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
outputs_size.append(op.attributes["out_size"][idx])
func_name = "model_forward"
args = ", ".join([f"const {dtype}* {name}" for name,
dtype in zip(inputs_name, inputs_dtype)])
args += ", " +", ".join([f"{dtype}** {name}" for name,
dtype in zip(outputs_name, outputs_dtype)])
forward_func = f"void {func_name}({args})"
ROOT = Path(__file__).resolve().parents[0]
generate_file(
str(dnn_folder / "src" / "forward.cpp"),
......@@ -114,7 +106,7 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
headers=set(list_configs),
actions=list_actions,
mem_ctype=inputs_dtype[0], # Legacy behavior ...
mem_section=export_lib.mem_section,
mem_section=export_lib.mem_section,
peak_mem=peak_mem,
inputs_name=inputs_name,
inputs_dtype=inputs_dtype,
......@@ -137,22 +129,6 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.")
if export_lib is not None and export_lib.main_jinja_path is not None:
main_jinja_path = export_lib.main_jinja_path
else :
main_jinja_path = str(ROOT / "templates" / "main.jinja")
generate_file(
str(export_folder / "main.cpp"),
main_jinja_path,
func_name=func_name,
inputs_name=inputs_name,
outputs_name=outputs_name,
outputs_dtype=outputs_dtype,
outputs_size=outputs_size,
labels=labels
)
if export_lib is not None:
# Copy all static files in the export
for source, destination in export_lib.static_files.items():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment