diff --git a/aidge_export_arm_cortexm/export.py b/aidge_export_arm_cortexm/export.py index 110e9431924ae6ad45a9d6fa852a1201e4779428..d1f6595e89eb327ddb67b39a21d708c9078ea3ab 100644 --- a/aidge_export_arm_cortexm/export.py +++ b/aidge_export_arm_cortexm/export.py @@ -9,11 +9,16 @@ from aidge_core.mem_info import compute_default_mem_info, generate_optimized_mem from aidge_core.export_utils import scheduler_export +BOARD_PATH : str = ROOT / "boards" + +BOARDS_MAP: dict[str, Path] = { + "stm32h7" : BOARD_PATH / "stm32" / "H7", +} + def export(export_folder_name, graphview, scheduler = None, board:str ="stm32h7", - library:str = "aidge", mem_wrapping = False): scheduler_export( @@ -21,137 +26,30 @@ def export(export_folder_name, export_folder_name, ExportLibAidgeARM, memory_manager=generate_optimized_memory_info, - memory_manager_args={"stats_folder": f"{export_folder_name}/stats", "wrapping":False } + memory_manager_args={"stats_folder": f"{export_folder_name}/stats", "wrapping": mem_wrapping } ) - # Create export directory - export_folder = Path().absolute() / export_folder_name - os.makedirs(str(export_folder), exist_ok=True) - # Create dnn directory - dnn_folder = export_folder / "dnn" + gen_board_files(export_folder_name, board) + + +def supported_boards() -> list[str]: + return BOARDS_MAP.keys() + +def gen_board_files(path:str, board:str)->None: + if board not in supported_boards(): + raise ValueError(f"Board {board} is not supported, supported board are:\n\t-{'\n\t-'.join(supported_boards())}") + + if isinstance(path, str): path = Path(path) + # Create dnn directory is not exist + dnn_folder = path / "dnn" os.makedirs(str(dnn_folder), exist_ok=True) # Determine which board the user wants # to select correct config - board_path = ROOT / "boards" / "stm32" / "H7" + # Copy all static files in the export - shutil.copytree(board_path, str(export_folder), dirs_exist_ok=True) + shutil.copytree(BOARDS_MAP[board], str(path), dirs_exist_ok=True) # For N2D2 library, copy static folder to export/include dnn_include_folder = dnn_folder / "include" os.makedirs(str(dnn_include_folder), exist_ok=True) shutil.copytree(str(ROOT / "_N2D2" / "static"), str(dnn_include_folder), dirs_exist_ok=True) - - # # Create statistics directory - # stats_folder = export_folder / "statistics" - # os.makedirs(str(stats_folder), exist_ok=True) - - # # Sort layers according to a scheduler - # if not isinstance(scheduler, aidge_core.Scheduler): - # # No scheduler provided by the user, use the default one - # list_forward_nodes = topological_sort(graphview) - # mem_size, mem_info = compute_default_mem_info(list_forward_nodes) - # else: - # list_forward_nodes = scheduler.get_static_scheduling() - # mem_size, mem_info = generate_optimized_memory_info(stats_folder, scheduler, mem_wrapping) - - # Set some lists of elements for generating forward file - # list_actions = [] - # list_configs = [] - - # # Export layer configurations - # for node in list_forward_nodes: - # if node.type() == "Producer": - # # We do not treat Producer here but i the nodes which use them - # continue - - # if node.type() in supported_operators(): - # op = OPERATORS_REGISTRY[node.type()](node, board, library) - # # Export the configuration - # list_configs = op.export(dnn_folder, list_configs) - - # # Add forward kernel - # list_actions = op.forward(list_actions) - # else: - # print(f"Warning: {node.type()} is not supported in the export.\nPlease add the implementation.") - - # Generate the memory file - # generate_file( - # str(dnn_folder / "memory" / "mem_info.h"), - # str(ROOT / "templates" / "memory" / "mem_info.jinja"), - # mem_size = mem_size, - # mem_info_legends = MEMORY_INFO_TEMPLATE, - # mem_info = mem_info, - # mem_alignment = 1 # Fixed memory alignement so far, feel free to adapt it - # ) - # list_configs.append("memory/mem_info.h") - - # Get entry nodes - # It supposes the entry nodes are producers with constant=false - # Store the datatype & name - # list_inputs_name = [] - # first_element_added = False - # for node in graphview.get_nodes(): - # if node.type() == "Producer": - # if not first_element_added: - # export_type = aidge2c(node.get_operator().get_output(0).dtype()) - # list_inputs_name.append((export_type, node.name())) - # first_element_added = True - # if not node.get_operator().attr.constant: - # export_type = aidge2c(node.get_operator().get_output(0).dtype()) - # list_inputs_name.append((export_type, node.name())) - - # Get output nodes - # Store the datatype & name, like entry nodes - - # list_outputs_name = [] - # for node in graphview.get_nodes(): - # if len(node.get_children()) == 0: - # if node.get_operator().attr.has_attr('dtype'): - # # Temporary fix because impossible to set DataType of a generic operator - # export_type = aidge2c(node.get_operator().attr.dtype) - # else: - # export_type = aidge2c(node.get_operator().get_output(0).dtype()) - - # list_outputs_name.append((export_type, node.name())) - - # if library == "n2d2": - # forward_file = "forward.cpp" - # else: - # forward_file = "forward.c" - - # # Generate forward file - # generate_file( - # str(dnn_folder / "src" / forward_file), - # str(ROOT / "templates" / "network" / "network_forward.jinja"), - # headers=set(list_configs), - # actions=list_actions, - # inputs= list_inputs_name, - # outputs=list_outputs_name - # ) - - # # Generate dnn internal API - # if library == "aidge": - # # For Aidge, parse all kernels source code and retrieve function prototypes - # generate_file( - # str(dnn_folder / "include" / "network_functions.h"), - # str(ROOT / "templates" / "network" / "network_prototypes.jinja"), - # libraries=[], - # functions=get_functions_from_c_folder(str(dnn_folder / "src" / "kernels")), - # ) - # elif library == "n2d2": - # # For N2D2, parse all the files in include/kernel/ and retrieve the names of the files - # generate_file( - # str(dnn_folder / "include" / "network_functions.h"), - # str(ROOT / "templates" / "network" / "network_prototypes.jinja"), - # libraries=[], - # files=[str(Path("kernels") / x) for x in get_filenames_from_folder(str(dnn_folder / "include" / "kernels"), r'^.*\.hpp$')], - # ) - - # # Generate dnn API - # generate_file( - # str(dnn_folder / "include" / "dnn.h"), - # str(ROOT / "templates" / "network" / "dnn_header.jinja"), - # libraries=["stdint.h"], - # functions=get_functions_from_c_file(str(dnn_folder / "src" / forward_file)), - # ) - diff --git a/aidge_export_arm_cortexm/operators.py b/aidge_export_arm_cortexm/operators.py index b4e6aa85cd8709dd0944c4eaeb2b25560ab8b945..69281a8dd1e17d296ece79bea186c103b6232f36 100644 --- a/aidge_export_arm_cortexm/operators.py +++ b/aidge_export_arm_cortexm/operators.py @@ -197,7 +197,6 @@ class Scaling(): class ReLU_ARMCortexM(ExportNodeCpp): def __init__(self, node, mem_info, is_input, is_output): super().__init__(node, mem_info, is_input, is_output) - self.attributes["activation_type"] = "RELU" self.config_template = str(ROOT / "_Aidge_Arm" / "templates" / "configuration" / "relu.jinja")