diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index be7815f3bb057eba343b3296d34fb644a9ce81c4..21e228c93a06e933cbd7bfcb0ff3650be6760200 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -26,6 +26,9 @@ class ExportScheduler(): os.makedirs(str(dnn_folder), exist_ok=True) if memory_manager_args is None: memory_manager_args = {} + + if memory_manager is None: + raise ValueError("A memory manager is required (no default value yet).") peak_mem, mem_info = memory_manager( self.scheduler, **memory_manager_args) @@ -46,7 +49,6 @@ class ExportScheduler(): if export_lib is not None: for node in list_forward_nodes: if export_lib.exportable(node): - is_input = node in self.graphview.get_input_nodes() is_output = node in self.graphview.get_output_nodes() op = export_lib.get_export_node(node)( @@ -56,10 +58,12 @@ class ExportScheduler(): # For forward file list_actions = op.forward(list_actions) if is_input: - for idx in range(len(node.inputs())): - inputs_name.append(op.attributes["in_name"][idx]) - inputs_dtype.append( - op.attributes["in_cdtype"][idx]) + for idx, node in enumerate(node.inputs()): + if node[0] not in self.graphview.get_nodes(): + inputs_name.append(op.attributes["in_name"][idx]) + inputs_dtype.append( + op.attributes["in_cdtype"][idx] + ) if is_output: for idx in range(len(node.outputs())): outputs_name.append(op.attributes["out_name"][idx]) @@ -73,14 +77,13 @@ class ExportScheduler(): raise ValueError("Current export only support export lib.") func_name = "model_forward" - args = ", ".join([f"{dtype} const {name}*" for name, + args = ", ".join([f"const {dtype}* {name}" for name, dtype in zip(inputs_name, inputs_dtype)]) - args += ", ".join([f"{dtype} {name}*" for name, + args += ", " +", ".join([f"{dtype}* {name}" for name, dtype in zip(outputs_name, outputs_dtype)]) - forward_func = f"void {func_name}()" + forward_func = f"void {func_name}({args})" ROOT = Path(__file__).resolve().parents[0] - generate_file( str(dnn_folder / "src" / "forward.cpp"), str(ROOT / "templates" / "forward.jinja"), diff --git a/aidge_core/export_utils/templates/forward_header.jinja b/aidge_core/export_utils/templates/forward_header.jinja index ad9f7744149d39a8b0fcd25eca7e496790f63f46..1e14cdea67f0e9da930c0907fa1d82624a8579e4 100644 --- a/aidge_core/export_utils/templates/forward_header.jinja +++ b/aidge_core/export_utils/templates/forward_header.jinja @@ -1,4 +1,4 @@ -$#ifndef DNN_HPP +#ifndef DNN_HPP #define DNN_HPP {#- For libraries #} diff --git a/aidge_core/export_utils/templates/main.jinja b/aidge_core/export_utils/templates/main.jinja index bbe0df8cb891e410372c60038c7f8cffcc14eabb..06945222fcfc5a3d11852d66e2706bf1881eeb79 100644 --- a/aidge_core/export_utils/templates/main.jinja +++ b/aidge_core/export_utils/templates/main.jinja @@ -1,6 +1,6 @@ #include <iostream> -#include "dnn.hpp" +#include "forward.hpp" #include "inputs.h" int main()