Skip to content
Snippets Groups Projects
Commit c788987e authored by Axel Farrugia's avatar Axel Farrugia
Browse files

[Refactor] Change the parameters template to improve readability

parent ad1bf296
No related branches found
No related tags found
2 merge requests!710.4.0,!58Fix the aidge_cmp feature
......@@ -2,29 +2,12 @@ import os
from pathlib import Path
import numpy as np
import aidge_core
from aidge_core.export_utils import ExportNodeCpp, generate_file
from aidge_core.export_utils import ExportNodeCpp, generate_file, aidge2c
from aidge_export_cpp import ROOT
from aidge_export_cpp import ExportLibCpp
def numpy_dtype2ctype(dtype):
if dtype == np.int8:
return "int8_t"
elif dtype == np.int16:
return "int16_t"
elif dtype == np.int32:
return "int32_t"
elif dtype == np.int64:
return "int64_t"
elif dtype == np.float32:
return "float"
elif dtype == np.float64:
return "double"
# Add more dtype mappings as needed
else:
raise ValueError(f"Unsupported {dtype} dtype")
def export_params(name: str,
array: np.ndarray,
output: aidge_core.Tensor,
filepath: str):
# Get directory name of the file
......@@ -38,15 +21,16 @@ def export_params(name: str,
filepath,
str(ROOT / "templates" / "data" / "parameters.jinja"),
name=name,
data_t=numpy_dtype2ctype(array.dtype),
values=array.tolist()
dims=output.dims(),
dtype=aidge2c(output.dtype()),
values=np.array(output).tolist()
)
@ExportLibCpp.register("Producer", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.any)))
class ProducerCPP(ExportNodeCpp):
def __init__(self, node, mem_info):
super().__init__(node, mem_info)
self.values = np.array(self.operator.get_output(0))
self.output = self.operator.get_output(0)
self.ignore = node.attributes().has_attr("ignore")
def export(self, export_folder: Path):
......@@ -67,7 +51,7 @@ class ProducerCPP(ExportNodeCpp):
header_path = f"include/parameters/{self.attributes['name']}.h"
export_params(
self.attributes['out_name'][0],
self.values.reshape(-1),
self.output,
str(export_folder / header_path))
return [path_to_definition, header_path]
......
{#- For libraries -#}
#include <stdint.h>
{%- set format_map = {
"int8_t": "%4d",
"int32_t": "%6d",
"float": "%8.3f"
} %}
{# Design header of the array -#}
static const {{ data_t }} {{ name }}[{{ values|length }}] __attribute__((section(".nn_data"))) =
static const {{ dtype }} {{ name }}[{{ dims | join("*") }}] __attribute__((section(".nn_data"))) =
{
{# For loop to add new elements -#}
{%- for i in range(values|length) %}
{# 1D #}
{%- if dims | length == 1 -%}
{%- for x in range(dims[0]) -%}
{{ format_map[dtype] | format(values[x]) }},
{%- endfor -%}
{%- endif -%}
{#- 2D #}
{%- if dims | length == 2 -%}
{%- for y in range(dims[0]) %}
{{ ' ' }}
{%- for x in range(dims[1]) -%}
{{ format_map[dtype] | format(values[y][x]) }},
{%- endfor %}
{%- endfor -%}
{%- endif -%}
{#- 3D #}
{%- if dims | length == 3 -%}
{%- for z in range(dims[0]) %}
{{ ' ' }}
{%- for y in range(dims[1]) %}
{{ ' ' }}
{%- for x in range(dims[2]) -%}
{{ format_map[dtype] | format(values[z][y][x]) }},
{%- endfor -%}
{%- endfor %}
{%- endfor -%}
{%- endif -%}
{#- Last value -#}
{%- if (i+1) == values|length -%}
{{ values[i]|string }}
{%- else -%}
{%- if (i+1) % 5 == 0 -%}
{{ values[i]|string + ",\n\t" }}
{%- else -%}
{{ values[i]|string + ", " }}
{%- endif -%}
{%- endif -%}
{#- 4D #}
{%- if dims | length == 4 -%}
{%- for n in range(dims[0]) %}
{{ ' ' }}
{%- for z in range(dims[1]) %}
{{ ' ' }}
{%- for y in range(dims[2]) %}
{{ ' ' }}
{%- for x in range(dims[3]) -%}
{{ format_map[dtype] | format(values[n][z][y][x]) }},
{%- endfor -%}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endif %}
};
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment