Skip to content
Snippets Groups Projects
Commit 33fd590a authored by Vincent Templier's avatar Vincent Templier
Browse files

Change module entry point

parent 09d81703
Branches
Tags
2 merge requests!90.1.1,!8Refactor module
import re
from . import register
from . import operators
r"""
Aidge Export for CPP standalone projects
from .register import *
from .operators import *
Use this module to generate CPP generic exports.
This module has to be used with the Aidge suite
"""
from aidge_export_cpp.utils import ROOT
def get_functions_from_c_file(file_path):
functions = []
pattern = r'\w+\s+(\w+)\s*\(([^)]*)\)\s*{'
__version__ = open(ROOT / "version.txt", "r").read().strip()
with open(file_path, 'r') as file:
file_content = file.read()
matches = re.findall(pattern, file_content)
for match in matches:
function_name = match[0]
arguments = match[1].split(',')
arguments = [arg.strip() for arg in arguments]
return_type = get_return_type(file_content, function_name)
function_string = f"{return_type} {function_name}({', '.join(arguments)});"
functions.append(function_string)
return functions
def get_return_type(file_content, function_name):
pattern = rf'\w+\s+{function_name}\s*\([^)]*\)\s*{{'
return_type = re.search(pattern, file_content).group()
return_type = return_type.split()[0].strip()
return return_type
def export(export_folder, graphview, scheduler):
os.makedirs(export_folder, exist_ok=True)
dnn_folder = export_folder + "/dnn"
os.makedirs(dnn_folder, exist_ok=True)
list_actions = []
list_configs = []
list_forward_nodes = [i for i in scheduler.get_static_scheduling() if i.type() != "Producer"]
list_op = {}
for node in graphview.get_nodes():
if node.type() in supported_operators():
op = EXPORT_CPP_REGISTRY[node.type()](node)
list_op[node.name()] = op
else:
continue
list_configs = op.export(dnn_folder, list_configs)
for node in list_forward_nodes:
list_actions = list_op[node.name()].forward(list_actions)
# Memory management
mem_offsets = []
mem_size = 0
for i, node in enumerate(list_forward_nodes):
if i != len(list_forward_nodes) - 1:
mem_offsets.append(f"{node.name().upper()}_OFFSET {mem_size}")
dims = node.get_operator().get_output(0).dims()
mem = 1
for dim in dims:
mem *= dim
mem_size += mem
# Generate the memory file
generate_file(
f"{dnn_folder}/memory/mem_info.h",
dirpath + "/templates/memory/mem_info.jinja",
mem_size=mem_size,
offsets=mem_offsets
)
list_configs.append("memory/mem_info.h")
generate_file(
f"{dnn_folder}/src/forward.cpp",
dirpath + "/templates/network/network_forward.jinja",
headers=list_configs,
actions=list_actions,
input_t="float",
inputs= list_forward_nodes[0].name()+"_input" if list_forward_nodes[0].get_parents()[0] is None else list_forward_nodes[0].get_parents()[0].name(),
output_t="float",
outputs=list_forward_nodes[-1].name()
)
generate_file(
f"{dnn_folder}/include/dnn.hpp",
dirpath + "/templates/network/dnn_header.jinja",
libraries=[],
functions=get_functions_from_c_file(f"{dnn_folder}/src/forward.cpp"),
)
# Copy all static files in the export
shutil.copy(dirpath + "/static/main.cpp", export_folder)
shutil.copy(dirpath + "/static/Makefile", export_folder)
shutil.copytree(dirpath + "/static/include", dnn_folder + "/include/", dirs_exist_ok=True)
from .export import *
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment