diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index 33869ea5ada0839e328d5f9edc43dcb84defbe6f..467504b631d92aa76ef0ebc6394ecde525582cf6 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -23,9 +23,9 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = peak_mem, mem_info = memory_manager( scheduler, **memory_manager_args) - # List of function call + # List of function call for forward.cpp list_actions: List[str] = [] - # List of headers to include to get the configuration files + # List of headers for forward.cpp list_configs: List[str] = [] inputs_name: List[str] = [] @@ -34,7 +34,9 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = outputs_dtype: List[str] = [] outputs_size: List[int] = [] - list_forward_nodes = scheduler.get_static_scheduling() + # List of aidge_core.Node ordered by scheduler + list_forward_nodes: List[aidge_core.Node] = scheduler.get_static_scheduling() + # If exportLib define use it # else parse component in platform # if export_lib is None: @@ -64,13 +66,19 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = in_graph:bool = node_in[0] in graphview.get_nodes() is_input &= (in_graph or not optional) + # Get operator current specs required_specs = op_impl.get_required_spec() + # Get specs of the implementation that match current specs specs = op_impl.get_best_match(required_specs) + # Retrieve said implementation export_node = op_impl.get_export_node(specs) + if export_node is None: raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].") + # Instanciate ExportNode op = export_node( node, mem_info[node], is_input, is_output) + # For configuration files list_configs += op.export(dnn_folder) # For forward file @@ -91,9 +99,11 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = outputs_size.append(op.attributes["out_size"][idx]) func_name = "model_forward" + + args = ", ".join([f"const {dtype}* {name}" for name, dtype in zip(inputs_name, inputs_dtype)]) - args += ", " +", ".join([f"{dtype}* {name}" for name, + args += ", " +", ".join([f"{dtype}** {name}" for name, dtype in zip(outputs_name, outputs_dtype)]) forward_func = f"void {func_name}({args})" @@ -101,11 +111,15 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = generate_file( str(dnn_folder / "src" / "forward.cpp"), str(ROOT / "templates" / "forward.jinja"), - forward_function=forward_func, + func_name=func_name, headers=set(list_configs), actions=list_actions, mem_ctype=inputs_dtype[0], # Legacy behavior ... - peak_mem=peak_mem + peak_mem=peak_mem, + inputs_name=inputs_name, + inputs_dtype=inputs_dtype, + outputs_name=outputs_name, + outputs_dtype=outputs_dtype ) # Generate dnn API @@ -113,11 +127,15 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = str(dnn_folder / "include" / "forward.hpp"), str(ROOT / "templates" / "forward_header.jinja"), libraries=[], - functions=[forward_func], + func_name=func_name, + inputs_name=inputs_name, + inputs_dtype=inputs_dtype, + outputs_name=outputs_name, + outputs_dtype=outputs_dtype ) if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): - raise RuntimeError("FATAL: Output args list does not have the same lenght this is an internal bug.") + raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.") generate_file( str(export_folder / "main.cpp"), diff --git a/aidge_core/export_utils/templates/forward.jinja b/aidge_core/export_utils/templates/forward.jinja index 3c3ef8a6c14ad9f97cdd90ac5604b16969b5e358..a58bd168a9a0d5611c988cafb3c13b410e47a27a 100644 --- a/aidge_core/export_utils/templates/forward.jinja +++ b/aidge_core/export_utils/templates/forward.jinja @@ -18,17 +18,20 @@ static {{mem_ctype}} mem[{{peak_mem}}]; {# Forward function #} {#- Support multiple inputs with different datatypes and multiple outputs with different datatypes -#} -{{ forward_function }} +void {{ func_name }} ( + {%- for i in range(inputs_name | length) %} + const {{ inputs_dtype[i] }}* {{ inputs_name[i] }}, + {% endfor -%} + {%- for o in range(outputs_name | length) %} + {{ outputs_dtype[o] }}** {{ outputs_name[o] }}_ptr{% if not loop.last %}, {% endif %} + {% endfor -%}) { - #ifdef SAVE_OUTPUTS - // Creation of the outputs directory - struct stat st {}; - if (stat("outputs", &st) == -1) { - mkdir("outputs", 0700); - } - #endif {%- for action in actions %} {{ action }} {%- endfor %} + + {%- for output_name in outputs_name %} + *{{ output_name }}_ptr = {{ output_name }}; + {%- endfor %} } diff --git a/aidge_core/export_utils/templates/forward_header.jinja b/aidge_core/export_utils/templates/forward_header.jinja index 5cf41742a555551429b420620dedc1bd16fae0c4..8530d91e995fc3c4e25b0e148f4cf1f4ded9ea0b 100644 --- a/aidge_core/export_utils/templates/forward_header.jinja +++ b/aidge_core/export_utils/templates/forward_header.jinja @@ -10,9 +10,13 @@ extern "C" { #include <{{ lib }}> {%- endfor %} -{% for func in functions %} -{{ func }}; -{% endfor %} +void {{ func_name }} ( + {%- for i in range(inputs_name | length) %} + const {{ inputs_dtype[i] }}* {{ inputs_name[i] }}, + {% endfor -%} + {%- for o in range(outputs_name | length) %} + {{ outputs_dtype[o] }}** {{ outputs_name[o] }}{% if not loop.last %}, {% endif %} + {% endfor -%}); #ifdef __cplusplus } diff --git a/aidge_core/export_utils/templates/main.jinja b/aidge_core/export_utils/templates/main.jinja index 2a563f5538d4b00591ca47f154bafbdb521ac7e9..d0c22719a5d9b9eaa15d3ef9ef86307a060b54be 100644 --- a/aidge_core/export_utils/templates/main.jinja +++ b/aidge_core/export_utils/templates/main.jinja @@ -22,11 +22,11 @@ int main() { // Initialize the output arrays {%- for o in range(outputs_name | length) %} - {{ outputs_dtype[o] }} {{ outputs_name[o] }}[{{ outputs_size[o] }}]; + {{ outputs_dtype[o] }}* {{ outputs_name[o] }} = nullptr; {% endfor %} // Call the forward function - {{ func_name }}({{ inputs_name|join(", ") }}, {{ outputs_name|join(", ") }}); + {{ func_name }}({{ inputs_name|join(", ") }}, &{{ outputs_name|join(", &") }}); // Print the results of each output {%- for o in range(outputs_name | length) %}