Skip to content
Snippets Groups Projects
Commit b3fe7fd0 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Update forward and export signature.

parent 91458013
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
......@@ -269,12 +269,7 @@ class ExportNodeCpp(ExportNode):
# Path where all the kernels are stored in the export (prefixed by export_root/include)
kernels_path: str = "kernels"
# def __init__(self, aidge_node: aidge_core.Node) -> None:
# """Create ExportNode and retieve attriubtes from ``aidge_node``:
# """
# super().__init__()
def export(self, export_folder: str, list_configs: list):
def export(self, export_folder: str):
"""Define how to export the node definition.
"""
if self.config_template is None:
......@@ -302,15 +297,15 @@ class ExportNodeCpp(ExportNode):
kernel_include_list.append(path_to_definition)
return list_configs + self.include_list + kernel_include_list
return self.include_list + kernel_include_list
def forward(self, list_actions: list):
def forward(self):
"""Define how to generate code to perform a forward pass.
"""
if self.forward_template is None:
raise ValueError("forward_template have not been defined")
list_actions.append(code_generation.generate_str(
forward_call: str = code_generation.generate_str(
self.forward_template,
**self.attributes
))
return list_actions
)
return [forward_call]
......@@ -6,31 +6,22 @@ from aidge_core.export_utils import ExportLib, generate_file, copy_file
from typing import List, Tuple
class ExportScheduler():
def __init__(self, scheduler: aidge_core.Scheduler, *args, **kwargs):
self.scheduler = scheduler
self.graphview = scheduler.graph_view()
def export(self,
export_folder_path: str,
export_lib: ExportLib = None,
platform=None,
memory_manager=None,
memory_manager_args=None
) -> None:
def export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None) -> None:
graphview = scheduler.graph_view()
export_folder = Path().absolute() / export_folder_path
os.makedirs(str(export_folder), exist_ok=True)
dnn_folder = export_folder / "dnn"
os.makedirs(str(dnn_folder), exist_ok=True)
if memory_manager_args is None:
memory_manager_args = {}
if memory_manager is None:
raise ValueError("A memory manager is required (no default value yet).")
peak_mem, mem_info = memory_manager(
self.scheduler, **memory_manager_args)
scheduler, **memory_manager_args)
# List of function call
list_actions: List[str] = []
......@@ -43,38 +34,37 @@ class ExportScheduler():
outputs_dtype: List[str] = []
outputs_size: List[int] = []
list_forward_nodes = self.scheduler.get_static_scheduling()
list_forward_nodes = scheduler.get_static_scheduling()
# If exportLib define use it
# else parse component in platform
if export_lib is not None:
for node in list_forward_nodes:
if export_lib.exportable(node):
is_input = node in self.graphview.get_input_nodes()
is_output = node in self.graphview.get_output_nodes()
op = export_lib.get_export_node(node)(
node, mem_info[node], is_input, is_output)
# For configuration files
list_configs = op.export(dnn_folder, list_configs)
# For forward file
list_actions = op.forward(list_actions)
if is_input:
for idx, node in enumerate(node.inputs()):
if node[0] not in self.graphview.get_nodes():
inputs_name.append(op.attributes["in_name"][idx])
inputs_dtype.append(
op.attributes["in_cdtype"][idx]
)
if is_output:
for idx in range(len(node.outputs())):
outputs_name.append(op.attributes["out_name"][idx])
outputs_dtype.append(
op.attributes["out_cdtype"][idx])
outputs_size.append(op.attributes["out_size"][idx])
else:
raise RuntimeError(
f"Operator not supported: {node.type()} for export lib {export_lib._name} !")
else:
raise ValueError("Current export only support export lib.")
if export_lib is None:
raise ValueError("Export need an ExportLib.")
for node in list_forward_nodes:
if export_lib.exportable(node):
is_input = node in graphview.get_input_nodes()
is_output = node in graphview.get_output_nodes()
op = export_lib.get_export_node(node)(
node, mem_info[node], is_input, is_output)
# For configuration files
list_configs += op.export(dnn_folder)
# For forward file
list_actions += op.forward()
if is_input:
for idx, node in enumerate(node.inputs()):
if node[0] not in graphview.get_nodes():
inputs_name.append(op.attributes["in_name"][idx])
inputs_dtype.append(
op.attributes["in_cdtype"][idx]
)
if is_output:
for idx in range(len(node.outputs())):
outputs_name.append(op.attributes["out_name"][idx])
outputs_dtype.append(
op.attributes["out_cdtype"][idx])
outputs_size.append(op.attributes["out_size"][idx])
else:
raise RuntimeError(
f"Operator not supported: {node.type()} for export lib {export_lib._name} !")
func_name = "model_forward"
args = ", ".join([f"const {dtype}* {name}" for name,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment