Skip to content
Snippets Groups Projects
Commit cdf3c22d authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Update output nodes of exports to be pointer of pointer.

parent cff451cd
No related branches found
No related tags found
No related merge requests found
......@@ -23,9 +23,9 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
peak_mem, mem_info = memory_manager(
scheduler, **memory_manager_args)
# List of function call
# List of function call for forward.cpp
list_actions: List[str] = []
# List of headers to include to get the configuration files
# List of headers for forward.cpp
list_configs: List[str] = []
inputs_name: List[str] = []
......@@ -34,7 +34,9 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
outputs_dtype: List[str] = []
outputs_size: List[int] = []
list_forward_nodes = scheduler.get_static_scheduling()
# List of aidge_core.Node ordered by scheduler
list_forward_nodes: List[aidge_core.Node] = scheduler.get_static_scheduling()
# If exportLib define use it
# else parse component in platform
# if export_lib is None:
......@@ -64,13 +66,19 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
in_graph:bool = node_in[0] in graphview.get_nodes()
is_input &= (in_graph or not optional)
# Get operator current specs
required_specs = op_impl.get_required_spec()
# Get specs of the implementation that match current specs
specs = op_impl.get_best_match(required_specs)
# Retrieve said implementation
export_node = op_impl.get_export_node(specs)
if export_node is None:
raise RuntimeError(f"Could not find export node for {node.name()}[{node.type()}].")
# Instanciate ExportNode
op = export_node(
node, mem_info[node], is_input, is_output)
# For configuration files
list_configs += op.export(dnn_folder)
# For forward file
......@@ -91,9 +99,11 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
outputs_size.append(op.attributes["out_size"][idx])
func_name = "model_forward"
args = ", ".join([f"const {dtype}* {name}" for name,
dtype in zip(inputs_name, inputs_dtype)])
args += ", " +", ".join([f"{dtype}* {name}" for name,
args += ", " +", ".join([f"{dtype}** {name}" for name,
dtype in zip(outputs_name, outputs_dtype)])
forward_func = f"void {func_name}({args})"
......@@ -101,11 +111,15 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
generate_file(
str(dnn_folder / "src" / "forward.cpp"),
str(ROOT / "templates" / "forward.jinja"),
forward_function=forward_func,
func_name=func_name,
headers=set(list_configs),
actions=list_actions,
mem_ctype=inputs_dtype[0], # Legacy behavior ...
peak_mem=peak_mem
peak_mem=peak_mem,
inputs_name=inputs_name,
inputs_dtype=inputs_dtype,
outputs_name=outputs_name,
outputs_dtype=outputs_dtype
)
# Generate dnn API
......@@ -113,11 +127,15 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib =
str(dnn_folder / "include" / "forward.hpp"),
str(ROOT / "templates" / "forward_header.jinja"),
libraries=[],
functions=[forward_func],
func_name=func_name,
inputs_name=inputs_name,
inputs_dtype=inputs_dtype,
outputs_name=outputs_name,
outputs_dtype=outputs_dtype
)
if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
raise RuntimeError("FATAL: Output args list does not have the same lenght this is an internal bug.")
raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.")
generate_file(
str(export_folder / "main.cpp"),
......
......@@ -18,17 +18,20 @@ static {{mem_ctype}} mem[{{peak_mem}}];
{# Forward function #}
{#- Support multiple inputs with different datatypes and multiple outputs with different datatypes -#}
{{ forward_function }}
void {{ func_name }} (
{%- for i in range(inputs_name | length) %}
const {{ inputs_dtype[i] }}* {{ inputs_name[i] }},
{% endfor -%}
{%- for o in range(outputs_name | length) %}
{{ outputs_dtype[o] }}** {{ outputs_name[o] }}_ptr{% if not loop.last %}, {% endif %}
{% endfor -%})
{
#ifdef SAVE_OUTPUTS
// Creation of the outputs directory
struct stat st {};
if (stat("outputs", &st) == -1) {
mkdir("outputs", 0700);
}
#endif
{%- for action in actions %}
{{ action }}
{%- endfor %}
{%- for output_name in outputs_name %}
*{{ output_name }}_ptr = {{ output_name }};
{%- endfor %}
}
......@@ -10,9 +10,13 @@ extern "C" {
#include <{{ lib }}>
{%- endfor %}
{% for func in functions %}
{{ func }};
{% endfor %}
void {{ func_name }} (
{%- for i in range(inputs_name | length) %}
const {{ inputs_dtype[i] }}* {{ inputs_name[i] }},
{% endfor -%}
{%- for o in range(outputs_name | length) %}
{{ outputs_dtype[o] }}** {{ outputs_name[o] }}{% if not loop.last %}, {% endif %}
{% endfor -%});
#ifdef __cplusplus
}
......
......@@ -22,11 +22,11 @@ int main()
{
// Initialize the output arrays
{%- for o in range(outputs_name | length) %}
{{ outputs_dtype[o] }} {{ outputs_name[o] }}[{{ outputs_size[o] }}];
{{ outputs_dtype[o] }}* {{ outputs_name[o] }} = nullptr;
{% endfor %}
// Call the forward function
{{ func_name }}({{ inputs_name|join(", ") }}, {{ outputs_name|join(", ") }});
{{ func_name }}({{ inputs_name|join(", ") }}, &{{ outputs_name|join(", &") }});
// Print the results of each output
{%- for o in range(outputs_name | length) %}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment