Skip to content
Snippets Groups Projects
Commit a5e8976d authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix multiple typo due to old naming.

parent 987cbfc8
No related branches found
No related tags found
No related merge requests found
...@@ -26,6 +26,9 @@ class ExportScheduler(): ...@@ -26,6 +26,9 @@ class ExportScheduler():
os.makedirs(str(dnn_folder), exist_ok=True) os.makedirs(str(dnn_folder), exist_ok=True)
if memory_manager_args is None: if memory_manager_args is None:
memory_manager_args = {} memory_manager_args = {}
if memory_manager is None:
raise ValueError("A memory manager is required (no default value yet).")
peak_mem, mem_info = memory_manager( peak_mem, mem_info = memory_manager(
self.scheduler, **memory_manager_args) self.scheduler, **memory_manager_args)
...@@ -46,7 +49,6 @@ class ExportScheduler(): ...@@ -46,7 +49,6 @@ class ExportScheduler():
if export_lib is not None: if export_lib is not None:
for node in list_forward_nodes: for node in list_forward_nodes:
if export_lib.exportable(node): if export_lib.exportable(node):
is_input = node in self.graphview.get_input_nodes() is_input = node in self.graphview.get_input_nodes()
is_output = node in self.graphview.get_output_nodes() is_output = node in self.graphview.get_output_nodes()
op = export_lib.get_export_node(node)( op = export_lib.get_export_node(node)(
...@@ -56,10 +58,12 @@ class ExportScheduler(): ...@@ -56,10 +58,12 @@ class ExportScheduler():
# For forward file # For forward file
list_actions = op.forward(list_actions) list_actions = op.forward(list_actions)
if is_input: if is_input:
for idx in range(len(node.inputs())): for idx, node in enumerate(node.inputs()):
inputs_name.append(op.attributes["in_name"][idx]) if node[0] not in self.graphview.get_nodes():
inputs_dtype.append( inputs_name.append(op.attributes["in_name"][idx])
op.attributes["in_cdtype"][idx]) inputs_dtype.append(
op.attributes["in_cdtype"][idx]
)
if is_output: if is_output:
for idx in range(len(node.outputs())): for idx in range(len(node.outputs())):
outputs_name.append(op.attributes["out_name"][idx]) outputs_name.append(op.attributes["out_name"][idx])
...@@ -73,14 +77,13 @@ class ExportScheduler(): ...@@ -73,14 +77,13 @@ class ExportScheduler():
raise ValueError("Current export only support export lib.") raise ValueError("Current export only support export lib.")
func_name = "model_forward" func_name = "model_forward"
args = ", ".join([f"{dtype} const {name}*" for name, args = ", ".join([f"const {dtype}* {name}" for name,
dtype in zip(inputs_name, inputs_dtype)]) dtype in zip(inputs_name, inputs_dtype)])
args += ", ".join([f"{dtype} {name}*" for name, args += ", " +", ".join([f"{dtype}* {name}" for name,
dtype in zip(outputs_name, outputs_dtype)]) dtype in zip(outputs_name, outputs_dtype)])
forward_func = f"void {func_name}()" forward_func = f"void {func_name}({args})"
ROOT = Path(__file__).resolve().parents[0] ROOT = Path(__file__).resolve().parents[0]
generate_file( generate_file(
str(dnn_folder / "src" / "forward.cpp"), str(dnn_folder / "src" / "forward.cpp"),
str(ROOT / "templates" / "forward.jinja"), str(ROOT / "templates" / "forward.jinja"),
......
$#ifndef DNN_HPP #ifndef DNN_HPP
#define DNN_HPP #define DNN_HPP
{#- For libraries #} {#- For libraries #}
......
#include <iostream> #include <iostream>
#include "dnn.hpp" #include "forward.hpp"
#include "inputs.h" #include "inputs.h"
int main() int main()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment