From b3fe7fd0d2bdfced968f3ac4ab711c319519eea0 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Wed, 24 Jul 2024 12:51:15 +0000 Subject: [PATCH] Update forward and export signature. --- aidge_core/export_utils/node_export.py | 17 ++--- aidge_core/export_utils/scheduler_export.py | 76 +++++++++------------ 2 files changed, 39 insertions(+), 54 deletions(-) diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index ce3d53501..5eba1647c 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -269,12 +269,7 @@ class ExportNodeCpp(ExportNode): # Path where all the kernels are stored in the export (prefixed by export_root/include) kernels_path: str = "kernels" - # def __init__(self, aidge_node: aidge_core.Node) -> None: - # """Create ExportNode and retieve attriubtes from ``aidge_node``: - # """ - # super().__init__() - - def export(self, export_folder: str, list_configs: list): + def export(self, export_folder: str): """Define how to export the node definition. """ if self.config_template is None: @@ -302,15 +297,15 @@ class ExportNodeCpp(ExportNode): kernel_include_list.append(path_to_definition) - return list_configs + self.include_list + kernel_include_list + return self.include_list + kernel_include_list - def forward(self, list_actions: list): + def forward(self): """Define how to generate code to perform a forward pass. """ if self.forward_template is None: raise ValueError("forward_template have not been defined") - list_actions.append(code_generation.generate_str( + forward_call: str = code_generation.generate_str( self.forward_template, **self.attributes - )) - return list_actions + ) + return [forward_call] diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index 21e228c93..c12f15ab7 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -6,31 +6,22 @@ from aidge_core.export_utils import ExportLib, generate_file, copy_file from typing import List, Tuple -class ExportScheduler(): - def __init__(self, scheduler: aidge_core.Scheduler, *args, **kwargs): - self.scheduler = scheduler - self.graphview = scheduler.graph_view() - - def export(self, - export_folder_path: str, - export_lib: ExportLib = None, - platform=None, - memory_manager=None, - memory_manager_args=None - ) -> None: +def export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None) -> None: + graphview = scheduler.graph_view() export_folder = Path().absolute() / export_folder_path os.makedirs(str(export_folder), exist_ok=True) dnn_folder = export_folder / "dnn" os.makedirs(str(dnn_folder), exist_ok=True) + if memory_manager_args is None: memory_manager_args = {} if memory_manager is None: raise ValueError("A memory manager is required (no default value yet).") peak_mem, mem_info = memory_manager( - self.scheduler, **memory_manager_args) + scheduler, **memory_manager_args) # List of function call list_actions: List[str] = [] @@ -43,38 +34,37 @@ class ExportScheduler(): outputs_dtype: List[str] = [] outputs_size: List[int] = [] - list_forward_nodes = self.scheduler.get_static_scheduling() + list_forward_nodes = scheduler.get_static_scheduling() # If exportLib define use it # else parse component in platform - if export_lib is not None: - for node in list_forward_nodes: - if export_lib.exportable(node): - is_input = node in self.graphview.get_input_nodes() - is_output = node in self.graphview.get_output_nodes() - op = export_lib.get_export_node(node)( - node, mem_info[node], is_input, is_output) - # For configuration files - list_configs = op.export(dnn_folder, list_configs) - # For forward file - list_actions = op.forward(list_actions) - if is_input: - for idx, node in enumerate(node.inputs()): - if node[0] not in self.graphview.get_nodes(): - inputs_name.append(op.attributes["in_name"][idx]) - inputs_dtype.append( - op.attributes["in_cdtype"][idx] - ) - if is_output: - for idx in range(len(node.outputs())): - outputs_name.append(op.attributes["out_name"][idx]) - outputs_dtype.append( - op.attributes["out_cdtype"][idx]) - outputs_size.append(op.attributes["out_size"][idx]) - else: - raise RuntimeError( - f"Operator not supported: {node.type()} for export lib {export_lib._name} !") - else: - raise ValueError("Current export only support export lib.") + if export_lib is None: + raise ValueError("Export need an ExportLib.") + for node in list_forward_nodes: + if export_lib.exportable(node): + is_input = node in graphview.get_input_nodes() + is_output = node in graphview.get_output_nodes() + op = export_lib.get_export_node(node)( + node, mem_info[node], is_input, is_output) + # For configuration files + list_configs += op.export(dnn_folder) + # For forward file + list_actions += op.forward() + if is_input: + for idx, node in enumerate(node.inputs()): + if node[0] not in graphview.get_nodes(): + inputs_name.append(op.attributes["in_name"][idx]) + inputs_dtype.append( + op.attributes["in_cdtype"][idx] + ) + if is_output: + for idx in range(len(node.outputs())): + outputs_name.append(op.attributes["out_name"][idx]) + outputs_dtype.append( + op.attributes["out_cdtype"][idx]) + outputs_size.append(op.attributes["out_size"][idx]) + else: + raise RuntimeError( + f"Operator not supported: {node.type()} for export lib {export_lib._name} !") func_name = "model_forward" args = ", ".join([f"const {dtype}* {name}" for name, -- GitLab