diff --git a/aidge_export_arm_cortexm/export.py b/aidge_export_arm_cortexm/export.py index cff56ed77e20a904e6e2d8b605bb318c3978d5ee..2041ca3ad6545607ef37ac3114c7d2a3be24fe2f 100644 --- a/aidge_export_arm_cortexm/export.py +++ b/aidge_export_arm_cortexm/export.py @@ -9,7 +9,7 @@ from aidge_export_arm_cortexm.utils import (ROOT, AVAILABLE_BOARDS, has_board, \ OPERATORS_REGISTRY, supported_operators) import aidge_export_arm_cortexm.operators from aidge_export_arm_cortexm.utils.scheduler import topological_sort -from aidge_export_arm_cortexm.utils.generation import get_functions_from_c_file, get_functions_from_c_folder +from aidge_export_arm_cortexm.utils.generation import get_functions_from_c_file, get_functions_from_c_folder, get_filenames_from_folder from aidge_export_arm_cortexm.utils.converter import * from aidge_export_arm_cortexm.memory import * @@ -19,7 +19,8 @@ def export(export_folder_name, graphview, scheduler = None, board:str ="stm32h7", - library:str = "aidge"): + library:str = "aidge", + mem_wrapping = False): # Create export directory export_folder = Path().absolute() / export_folder_name @@ -39,14 +40,24 @@ def export(export_folder_name, # Copy all static files in the export shutil.copytree(board_path, str(export_folder), dirs_exist_ok=True) + # For N2D2 library, copy static folder to export/include + if library == "n2d2": + dnn_include_folder = dnn_folder / "include" + os.makedirs(str(dnn_include_folder), exist_ok=True) + shutil.copytree(str(ROOT / "_N2D2" / "static"), str(dnn_include_folder), dirs_exist_ok=True) + + # Create statistics directory + stats_folder = export_folder / "statistics" + os.makedirs(str(stats_folder), exist_ok=True) + # Sort layers according to a scheduler - if not scheduler: - # No scheduler provided by the user - # use the default scheduler + if not isinstance(scheduler, aidge_core.Scheduler): + # No scheduler provided by the user, use the default one list_forward_nodes = topological_sort(graphview) + mem_size, mem_info = compute_default_mem_info(list_forward_nodes) else: list_forward_nodes = scheduler.get_static_scheduling() - + mem_size, mem_info = generate_optimized_memory_info(stats_folder, scheduler, mem_wrapping) # Set some lists of elements for generating forward file list_actions = [] @@ -54,6 +65,10 @@ def export(export_folder_name, # Export layer configurations for node in list_forward_nodes: + if node.type() == "Producer": + # We do not treat Producer here but i the nodes which use them + continue + if node.type() in supported_operators(): op = OPERATORS_REGISTRY[node.type()](node, board, library) @@ -62,10 +77,8 @@ def export(export_folder_name, # Add forward kernel list_actions = op.forward(list_actions) - - # Memory management - # TODO put the condition if no scheduler provided - mem_size, mem_info = compute_default_mem_info(scheduler) + else: + print(f"Warning: {node.type()} is not supported in the export.\nPlease add the implementation.") # Generate the memory file generate_file( @@ -73,7 +86,8 @@ def export(export_folder_name, str(ROOT / "templates" / "memory" / "mem_info.jinja"), mem_size = mem_size, mem_info_legends = MEMORY_INFO_TEMPLATE, - mem_info = mem_info + mem_info = mem_info, + mem_alignment = 1 # Fixed memory alignement so far, feel free to adapt it ) list_configs.append("memory/mem_info.h") @@ -92,12 +106,22 @@ def export(export_folder_name, list_outputs_name = [] for node in graphview.get_nodes(): if len(node.get_children()) == 0: - export_type = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + if node.get_operator().has_attr("DataType"): + # Temporary fix because impossible to set DataType of a generic operator + export_type = aidge_datatype2ctype(node.get_operator().get_attr("DataType")) + else: + export_type = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + list_outputs_name.append((export_type, node.name())) + if library == "n2d2": + forward_file = "forward.cpp" + else: + forward_file = "forward.c" + # Generate forward file generate_file( - str(dnn_folder / "src" / "forward.c"), + str(dnn_folder / "src" / forward_file), str(ROOT / "templates" / "network" / "network_forward.jinja"), headers=list_configs, actions=list_actions, @@ -106,18 +130,28 @@ def export(export_folder_name, ) # Generate dnn internal API - aidge_export_arm_cortexm.operators.generate_file( - str(dnn_folder / "include" / "network_functions.h"), - str(ROOT / "templates" / "network" / "dnn_header.jinja"), - libraries=[], - functions=get_functions_from_c_folder(str(dnn_folder / "src" / "kernels")), - ) + if library == "aidge": + # For Aidge, parse all kernels source code and retrieve function prototypes + generate_file( + str(dnn_folder / "include" / "network_functions.h"), + str(ROOT / "templates" / "network" / "network_prototypes.jinja"), + libraries=[], + functions=get_functions_from_c_folder(str(dnn_folder / "src" / "kernels")), + ) + elif library == "n2d2": + # For N2D2, parse all the files in include/kernel/ and retrieve the names of the files + generate_file( + str(dnn_folder / "include" / "network_functions.h"), + str(ROOT / "templates" / "network" / "network_prototypes.jinja"), + libraries=[], + files=[str(Path("kernels") / x) for x in get_filenames_from_folder(str(dnn_folder / "include" / "kernels"), r'^.*\.hpp$')], + ) # Generate dnn API - aidge_export_arm_cortexm.operators.generate_file( + generate_file( str(dnn_folder / "include" / "dnn.h"), str(ROOT / "templates" / "network" / "dnn_header.jinja"), - libraries=[], - functions=get_functions_from_c_file(str(dnn_folder / "src" / "forward.c")), + libraries=["stdint.h"], + functions=get_functions_from_c_file(str(dnn_folder / "src" / forward_file)), )