Skip to content
Snippets Groups Projects
Commit 9c313b20 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Export function now just call exportScheduler.

parent 9444564e
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
......@@ -39,93 +39,97 @@ def generate_input_file(export_folder:str,
def export(export_folder_name, graphview, scheduler, mem_wrapping=False):
es = aidge_core.export_utils.ExportScheduler(scheduler)
es.export(export_folder_name,
ExportLibCpp,
memory_manager=compute_default_mem_info)
export_folder = Path().absolute() / export_folder_name
# export_folder = Path().absolute() / export_folder_name
os.makedirs(str(export_folder), exist_ok=True)
# os.makedirs(str(export_folder), exist_ok=True)
dnn_folder = export_folder / "dnn"
os.makedirs(str(dnn_folder), exist_ok=True)
# dnn_folder = export_folder / "dnn"
# os.makedirs(str(dnn_folder), exist_ok=True)
list_actions = []
list_configs = []
peak_mem, mem_info = compute_default_mem_info(scheduler)
list_forward_nodes = scheduler.get_static_scheduling()
for node in list_forward_nodes:
if ExportLibCpp.exportable(node):
op = ExportLibCpp.get_export_node(node)(node, mem_info[node])
# For configuration files
list_configs = op.export(dnn_folder, list_configs)
# For forward file
list_actions = op.forward(list_actions)
else:
raise RuntimeError(f"Operator not supported: {node.type()} !")
# Memory management
# stats_folder = export_folder / "statistics"
# os.makedirs(str(stats_folder), exist_ok=True)
# mem_size, mem_info = generate_optimized_memory_info(stats_folder, scheduler, mem_wrapping)
# list_actions = []
# list_configs = []
# peak_mem, mem_info = compute_default_mem_info(scheduler)
# Generate the memory file
# list_forward_nodes = scheduler.get_static_scheduling()
# for node in list_forward_nodes:
# if ExportLibCpp.exportable(node):
# op = ExportLibCpp.get_export_node(node)(node, mem_info[node])
# # For configuration files
# list_configs = op.export(dnn_folder, list_configs)
# # For forward file
# list_actions = op.forward(list_actions)
# else:
# raise RuntimeError(f"Operator not supported: {node.type()} !")
# # Memory management
# # stats_folder = export_folder / "statistics"
# # os.makedirs(str(stats_folder), exist_ok=True)
# # mem_size, mem_info = generate_optimized_memory_info(stats_folder, scheduler, mem_wrapping)
# # peak_mem, mem_info = compute_default_mem_info(scheduler)
# # Generate the memory file
# # generate_file(
# # str(dnn_folder / "memory" / "mem_info.h"),
# # str(ROOT / "templates" / "memory" / "mem_info.jinja"),
# # mem_size = mem_size,
# # mem_info_legends = MEMORY_INFO_TEMPLATE,
# # mem_info = mem_info
# # )
# # list_configs.append("memory/mem_info.h")
# # Get entry nodes
# # Store the datatype & name
# list_inputs_name = []
# for node in graphview.get_input_nodes():
# for idx, node_input_tuple in enumerate(node.inputs()):
# node_input, _ = node_input_tuple
# if node_input is None:
# export_type = aidge2c(node.get_operator().get_output(0).dtype())
# list_inputs_name.append((export_type, f"{node.name()}_input_{idx}"))
# elif node_input not in graphview.get_nodes():
# export_type = aidge2c(node_input.get_operator().get_output(0).dtype())
# list_inputs_name.append((export_type, node_input.name()))
# # Get output nodes
# # Store the datatype & name, like entry nodes
# list_outputs_name = []
# for node in graphview.get_nodes():
# if len(node.get_children()) == 0:
# export_type = aidge2c(node.get_operator().get_output(0).dtype())
# list_outputs_name.append((export_type, f"{node.name()}_output_0"))
# # Generate forward file
# # TODO: for now the mem type is bound for all intermediate results, should change.
# # Note that we may have all inputs constants, hence select output type
# assert len(list_outputs_name) >= 1, f"TODO: requires some output to determine mem type"
# mem_ctype = list_outputs_name[0][0]
# generate_file(
# str(dnn_folder / "memory" / "mem_info.h"),
# str(ROOT / "templates" / "memory" / "mem_info.jinja"),
# mem_size = mem_size,
# mem_info_legends = MEMORY_INFO_TEMPLATE,
# mem_info = mem_info
# str(dnn_folder / "src" / "forward.cpp"),
# str(ROOT / "templates" / "network" / "network_forward.jinja"),
# headers=set(list_configs),
# actions=list_actions,
# inputs= list_inputs_name,
# outputs=list_outputs_name,
# mem_ctype=mem_ctype,
# peak_mem=peak_mem
# )
# list_configs.append("memory/mem_info.h")
# Get entry nodes
# Store the datatype & name
list_inputs_name = []
for node in graphview.get_input_nodes():
for idx, node_input_tuple in enumerate(node.inputs()):
node_input, _ = node_input_tuple
if node_input is None:
export_type = aidge2c(node.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, f"{node.name()}_input_{idx}"))
elif node_input not in graphview.get_nodes():
export_type = aidge2c(node_input.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, node_input.name()))
# Get output nodes
# Store the datatype & name, like entry nodes
list_outputs_name = []
for node in graphview.get_nodes():
if len(node.get_children()) == 0:
export_type = aidge2c(node.get_operator().get_output(0).dtype())
list_outputs_name.append((export_type, f"{node.name()}_output_0"))
# Generate forward file
# TODO: for now the mem type is bound for all intermediate results, should change.
# Note that we may have all inputs constants, hence select output type
assert len(list_outputs_name) >= 1, f"TODO: requires some output to determine mem type"
mem_ctype = list_outputs_name[0][0]
generate_file(
str(dnn_folder / "src" / "forward.cpp"),
str(ROOT / "templates" / "network" / "network_forward.jinja"),
headers=set(list_configs),
actions=list_actions,
inputs= list_inputs_name,
outputs=list_outputs_name,
mem_ctype=mem_ctype,
peak_mem=peak_mem
)
# Generate dnn API
generate_file(
str(dnn_folder / "include" / "dnn.hpp"),
str(ROOT / "templates" / "network" / "dnn_header.jinja"),
libraries=[],
functions=get_functions_from_c_file(str(dnn_folder / "src" / "forward.cpp")),
)
# # Generate dnn API
# generate_file(
# str(dnn_folder / "include" / "dnn.hpp"),
# str(ROOT / "templates" / "network" / "dnn_header.jinja"),
# libraries=[],
# functions=get_functions_from_c_file(str(dnn_folder / "src" / "forward.cpp")),
# )
# Copy all static files in the export
shutil.copy(str(ROOT / "static" / "main.cpp"), str(export_folder))
shutil.copy(str(ROOT / "static" / "Makefile"), str(export_folder))
shutil.copytree(str(ROOT / "static" / "include"), str(dnn_folder / "include"), dirs_exist_ok=True)
# # Copy all static files in the export
# shutil.copy(str(ROOT / "static" / "main.cpp"), str(export_folder))
# shutil.copy(str(ROOT / "static" / "Makefile"), str(export_folder))
# shutil.copytree(str(ROOT / "static" / "include"), str(dnn_folder / "include"), dirs_exist_ok=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment