Skip to content
Snippets Groups Projects
Commit 3cc3f05b authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add docstring to export functions.

parent 697e692b
No related branches found
No related tags found
No related merge requests found
import re
import os
import shutil
from aidge_export_arm_cortexm.utils import (ROOT, AVAILABLE_BOARDS, has_board, \
from aidge_export_arm_cortexm.utils import (ROOT, AVAILABLE_BOARDS, has_board,
OPERATORS_REGISTRY, supported_operators)
import aidge_export_arm_cortexm.operators
from aidge_export_arm_cortexm.utils.scheduler import topological_sort
from aidge_export_arm_cortexm.utils.generation import get_functions_from_c_file, get_functions_from_c_folder
import aidge_core # Used for type hint
@aidge_core.utils.template_docstring("available_board", ", ".join(AVAILABLE_BOARDS.keys()))
def export(export_folder: str,
graphview: aidge_core.GraphView,
scheduler: aidge_core.Scheduler = None,
board: str = "stm32h7") -> None:
"""Generate a STM32 export of an :py:class:`aidge_core.GraphView`.
:param export_folder: Location of the export to generate
:type export_folder: str
:param graphview: GraphView to export
:type graphview: :py:class:`aidge_core.GraphView`
:param scheduler: An aidge scheduler that will provide the ordering of the nodes, if ``None``, a topological ordering is performed, defaults to None
:type scheduler: :py:class:`aidge_core.Scheduler`, optional
:param board: String describing the type of board you want to export to, can be one of these values: [{available_board}], defaults to "stm32h7"
:type board: str, optional
"""
def export(export_folder,
graphview,
scheduler: list = None,
board: str ="stm32h7"):
# Create export directory
os.makedirs(export_folder, exist_ok=True)
......@@ -26,8 +38,9 @@ def export(export_folder,
if has_board(board):
board_path = AVAILABLE_BOARDS[board]
else:
raise ValueError(f"{board} not found in the package. Please among those boards: {list(AVAILABLE_BOARDS.keys())}")
raise ValueError(
f"{board} not found in the package. Please among those boards: {list(AVAILABLE_BOARDS.keys())}")
# Copy all static files in the export
shutil.copytree(board_path, export_folder, dirs_exist_ok=True)
......@@ -40,7 +53,6 @@ def export(export_folder,
# Not tested...
list_forward_nodes = scheduler.get_static_scheduling()
# Set some lists of elements for generating forward file
list_actions = []
list_configs = []
......@@ -48,7 +60,8 @@ def export(export_folder,
# Export layer configurations
for node in list_forward_nodes:
if node.type() in supported_operators():
op = OPERATORS_REGISTRY[node.type()](node, board, dataformat="float32", library="aidge")
op = OPERATORS_REGISTRY[node.type()](
node, board, dataformat="float32", library="aidge")
# Export the configuration
list_configs = op.export(dnn_folder, list_configs)
......@@ -114,10 +127,3 @@ def export(export_folder,
libraries=[],
functions=get_functions_from_c_file(f"{dnn_folder}/src/forward.c"),
)
......@@ -12,8 +12,15 @@ from aidge_export_arm_cortexm.utils import ROOT, operator_register
############## Export functions ##############
##############################################
def generate_file(filename, templatename, **kwargs):
def generate_file(filename: str, templatename: str, **kwargs):
"""Generate a file using a Jinja template
:param filename: Path of the file to create
:type filename: str
:param templatename: Path of the template to use for the forward
:type templatename: str
:param \**kwargs: kwargs are passed to Jinja for template generation.
"""
# Get directory name of the file
dirname = os.path.dirname(filename)
......@@ -110,14 +117,14 @@ class Add(ExportNode):
self.dataformat = dataformat
def export(self, export_folder:str, list_configs:list):
# Copying kernel into export
# Find a more generic system for future dev
if self.library == "aidge":
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "ElemWise" / "Add" / "aidge_add_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -129,7 +136,7 @@ class Add(ExportNode):
elemwise_op="\"ADD\"",
nb_inputs=np.prod(self.inputs_dims[0]),
nb_outputs=np.prod(self.outputs_dims[0]))
return list_configs
def forward(self, list_actions:list):
......@@ -160,14 +167,14 @@ class Sub(ExportNode):
self.dataformat = dataformat
def export(self, export_folder:str, list_configs:list):
# Copying kernel into export
# Find a more generic system for future dev
if self.library == "aidge":
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "ElemWise" / "Sub" / "aidge_sub_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -179,7 +186,7 @@ class Sub(ExportNode):
elemwise_op="\"SUB\"",
nb_inputs=np.prod(self.inputs_dims[0]),
nb_outputs=np.prod(self.outputs_dims[0]))
return list_configs
def forward(self, list_actions:list):
......@@ -210,14 +217,14 @@ class Mul(ExportNode):
self.dataformat = dataformat
def export(self, export_folder:str, list_configs:list):
# Copying kernel into export
# Find a more generic system for future dev
if self.library == "aidge":
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "ElemWise" / "Mul" / "aidge_mul_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -229,7 +236,7 @@ class Mul(ExportNode):
elemwise_op="\"MUL\"",
nb_inputs=np.prod(self.inputs_dims[0]),
nb_outputs=np.prod(self.outputs_dims[0]))
return list_configs
def forward(self, list_actions:list):
......@@ -260,14 +267,14 @@ class Div(ExportNode):
self.dataformat = dataformat
def export(self, export_folder:str, list_configs:list):
# Copying kernel into export
# Find a more generic system for future dev
if self.library == "aidge":
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "ElemWise" / "Div" / "aidge_div_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -279,7 +286,7 @@ class Div(ExportNode):
elemwise_op="\"DIV\"",
nb_inputs=np.prod(self.inputs_dims[0]),
nb_outputs=np.prod(self.outputs_dims[0]))
return list_configs
def forward(self, list_actions:list):
......@@ -318,7 +325,7 @@ class Gemm(ExportNode):
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "FullyConnected" / "aidge_fc_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -330,7 +337,7 @@ class Gemm(ExportNode):
nb_channels=self.inputs_dims[0][0],
nb_outputs=self.outputs_dims[0][0],
biases_size=self.outputs_dims[0][0])
return list_configs
def forward(self, list_actions:list):
......@@ -368,7 +375,7 @@ class Atan(ExportNode):
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "Activation" / "Atan" / "aidge_atan_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -380,7 +387,7 @@ class Atan(ExportNode):
activation_type="\"ATAN\"",
nb_inputs=np.prod(self.inputs_dims[0]),
nb_outputs=np.prod(self.outputs_dims[0]))
return list_configs
def forward(self, list_actions:list):
......@@ -424,7 +431,7 @@ class Slice(ExportNode):
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "Slice" / "aidge_slice_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -438,7 +445,7 @@ class Slice(ExportNode):
ends=self.ends,
nb_inputs=np.prod(self.inputs_dims[0]),
nb_outputs=np.prod(self.outputs_dims[0]))
return list_configs
def forward(self, list_actions:list):
......@@ -474,7 +481,7 @@ class Concat(ExportNode):
self.board = board
self.library = library
self.dataformat = dataformat
def export(self, export_folder:str, list_configs:list):
# Copying kernel into export
......@@ -483,7 +490,7 @@ class Concat(ExportNode):
if self.dataformat == "float32":
copyfile(str(ROOT / "kernels" / "Concat" / "aidge_concat_float32.c"),
str(Path(export_folder) / "src" / "kernels"))
# Add to config list the include of configurations
list_configs.append(f"layers/{self.name}.h")
......@@ -502,7 +509,7 @@ class Concat(ExportNode):
list_input_size=list_input_size,
output_size=np.sum(list_input_size)
)
return list_configs
def forward(self, list_actions:list):
......
......@@ -6,7 +6,12 @@ FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]
def get_all_available_boards():
def get_all_available_boards() -> map:
"""Return a map (key: board name, value: path to folder) of all boards available in the ``/boards/`` folder.
:return: Map (key: board name, value: path to folder) of all boards available in the export
:rtype: map[str, str]
"""
boards = {}
directory_path = Path(str(ROOT / "boards"))
......@@ -24,26 +29,35 @@ def get_all_available_boards():
board_name = relpath.replace('/', '').replace('\\', '')
boards[board_name.lower()] = str(subfolder)
return boards
AVAILABLE_BOARDS = get_all_available_boards()
def has_board(board_name: str) -> bool:
"""This function is not case sensitive.
:param board_name: Board name to check if available in the export
:type board_name: str
:return: If ``board_name`` is in ``AVAILABLE_BAORDS``, return True else, return False.
:rtype: bool
"""
return board_name.lower() in AVAILABLE_BOARDS.keys()
OPERATORS_REGISTRY = {}
def operator_register(*args):
"""Decorator to register a function to export an operator
"""
key_list = [arg for arg in args]
def decorator(operator):
def wrapper(*args, **kwargs):
return operator(*args, **kwargs)
for key in key_list:
OPERATORS_REGISTRY[key] = operator
......@@ -51,4 +65,8 @@ def operator_register(*args):
return decorator
def supported_operators():
return list(OPERATORS_REGISTRY.keys())
\ No newline at end of file
"""
:return: List of operators supported by the export.
:rtype: list(str)
"""
return list(OPERATORS_REGISTRY.keys())
import re
import os
def get_functions_from_c_file(file_path):
def get_functions_from_c_file(file_path: str) -> list:
"""Return the code content of every function contained in the file
:param file_path: Path of the file you want to extract the function from.
:type file_path: str
:return: List of function content
:rtype: list[str]
"""
functions = []
pattern = r'\w+\s+(\w+)\s*\(([^)]*)\)\s*{'
......@@ -21,18 +28,35 @@ def get_functions_from_c_file(file_path):
return functions
def get_return_type(file_content, function_name):
def get_return_type(file_content:str, function_name:str) -> str:
"""Given a file content and a function name, retrieve the return type of the function.
:param file_content: String containing the C code of a function.
:type file_content: str
:param function_name: name of the function
:type function_name: str
:return: The return type in C format.
:rtype: str
"""
pattern = rf'\w+\s+{function_name}\s*\([^)]*\)\s*{{'
return_type = re.search(pattern, file_content).group()
return_type = return_type.split()[0].strip()
return return_type
def get_functions_from_c_folder(folder_path):
def get_functions_from_c_folder(folder_path: str) -> list:
"""Retrieve all the C functions defined in a folder.
This function will return a list of string containing the C functions.
:param folder_path: Path of the folder from which to extract the functions content
:type folder_path: str
:return: _description_
:rtype: _type_
"""
functions = []
for _, _, files in os.walk(folder_path):
for file in files:
functions += get_functions_from_c_file(os.path.join(folder_path, file))
return functions
\ No newline at end of file
return functions
def topological_sort(graphview):
"""Take an Aidge Graphview
and returns a list of nodes topologically sorting
"""Take an Aidge Graphview and returns a list of nodes topologically sorting
"""
nodes = graphview.get_nodes()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment