Skip to content
Snippets Groups Projects
Commit a5e8976d authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Fix multiple typo due to old naming.

parent 987cbfc8
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
......@@ -26,6 +26,9 @@ class ExportScheduler():
os.makedirs(str(dnn_folder), exist_ok=True)
if memory_manager_args is None:
memory_manager_args = {}
if memory_manager is None:
raise ValueError("A memory manager is required (no default value yet).")
peak_mem, mem_info = memory_manager(
self.scheduler, **memory_manager_args)
......@@ -46,7 +49,6 @@ class ExportScheduler():
if export_lib is not None:
for node in list_forward_nodes:
if export_lib.exportable(node):
is_input = node in self.graphview.get_input_nodes()
is_output = node in self.graphview.get_output_nodes()
op = export_lib.get_export_node(node)(
......@@ -56,10 +58,12 @@ class ExportScheduler():
# For forward file
list_actions = op.forward(list_actions)
if is_input:
for idx in range(len(node.inputs())):
inputs_name.append(op.attributes["in_name"][idx])
inputs_dtype.append(
op.attributes["in_cdtype"][idx])
for idx, node in enumerate(node.inputs()):
if node[0] not in self.graphview.get_nodes():
inputs_name.append(op.attributes["in_name"][idx])
inputs_dtype.append(
op.attributes["in_cdtype"][idx]
)
if is_output:
for idx in range(len(node.outputs())):
outputs_name.append(op.attributes["out_name"][idx])
......@@ -73,14 +77,13 @@ class ExportScheduler():
raise ValueError("Current export only support export lib.")
func_name = "model_forward"
args = ", ".join([f"{dtype} const {name}*" for name,
args = ", ".join([f"const {dtype}* {name}" for name,
dtype in zip(inputs_name, inputs_dtype)])
args += ", ".join([f"{dtype} {name}*" for name,
args += ", " +", ".join([f"{dtype}* {name}" for name,
dtype in zip(outputs_name, outputs_dtype)])
forward_func = f"void {func_name}()"
forward_func = f"void {func_name}({args})"
ROOT = Path(__file__).resolve().parents[0]
generate_file(
str(dnn_folder / "src" / "forward.cpp"),
str(ROOT / "templates" / "forward.jinja"),
......
$#ifndef DNN_HPP
#ifndef DNN_HPP
#define DNN_HPP
{#- For libraries #}
......
#include <iostream>
#include "dnn.hpp"
#include "forward.hpp"
#include "inputs.h"
int main()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment