From a5e8976dc1e80dd06fc3c21fa53054fd41038031 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 23 Jul 2024 09:24:46 +0000 Subject: [PATCH] Fix multiple typo due to old naming. --- aidge_core/export_utils/scheduler_export.py | 21 +++++++++++-------- .../templates/forward_header.jinja | 2 +- aidge_core/export_utils/templates/main.jinja | 2 +- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index be7815f3b..21e228c93 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -26,6 +26,9 @@ class ExportScheduler(): os.makedirs(str(dnn_folder), exist_ok=True) if memory_manager_args is None: memory_manager_args = {} + + if memory_manager is None: + raise ValueError("A memory manager is required (no default value yet).") peak_mem, mem_info = memory_manager( self.scheduler, **memory_manager_args) @@ -46,7 +49,6 @@ class ExportScheduler(): if export_lib is not None: for node in list_forward_nodes: if export_lib.exportable(node): - is_input = node in self.graphview.get_input_nodes() is_output = node in self.graphview.get_output_nodes() op = export_lib.get_export_node(node)( @@ -56,10 +58,12 @@ class ExportScheduler(): # For forward file list_actions = op.forward(list_actions) if is_input: - for idx in range(len(node.inputs())): - inputs_name.append(op.attributes["in_name"][idx]) - inputs_dtype.append( - op.attributes["in_cdtype"][idx]) + for idx, node in enumerate(node.inputs()): + if node[0] not in self.graphview.get_nodes(): + inputs_name.append(op.attributes["in_name"][idx]) + inputs_dtype.append( + op.attributes["in_cdtype"][idx] + ) if is_output: for idx in range(len(node.outputs())): outputs_name.append(op.attributes["out_name"][idx]) @@ -73,14 +77,13 @@ class ExportScheduler(): raise ValueError("Current export only support export lib.") func_name = "model_forward" - args = ", ".join([f"{dtype} const {name}*" for name, + args = ", ".join([f"const {dtype}* {name}" for name, dtype in zip(inputs_name, inputs_dtype)]) - args += ", ".join([f"{dtype} {name}*" for name, + args += ", " +", ".join([f"{dtype}* {name}" for name, dtype in zip(outputs_name, outputs_dtype)]) - forward_func = f"void {func_name}()" + forward_func = f"void {func_name}({args})" ROOT = Path(__file__).resolve().parents[0] - generate_file( str(dnn_folder / "src" / "forward.cpp"), str(ROOT / "templates" / "forward.jinja"), diff --git a/aidge_core/export_utils/templates/forward_header.jinja b/aidge_core/export_utils/templates/forward_header.jinja index ad9f77441..1e14cdea6 100644 --- a/aidge_core/export_utils/templates/forward_header.jinja +++ b/aidge_core/export_utils/templates/forward_header.jinja @@ -1,4 +1,4 @@ -$#ifndef DNN_HPP +#ifndef DNN_HPP #define DNN_HPP {#- For libraries #} diff --git a/aidge_core/export_utils/templates/main.jinja b/aidge_core/export_utils/templates/main.jinja index bbe0df8cb..06945222f 100644 --- a/aidge_core/export_utils/templates/main.jinja +++ b/aidge_core/export_utils/templates/main.jinja @@ -1,6 +1,6 @@ #include <iostream> -#include "dnn.hpp" +#include "forward.hpp" #include "inputs.h" int main() -- GitLab