Skip to content
Snippets Groups Projects
Commit 034f98a5 authored by Vincent Templier's avatar Vincent Templier
Browse files

Update export and operator files with NodeExport

parent b2eb8a11
No related branches found
No related tags found
1 merge request!1Add Node export for Cpp export
...@@ -45,15 +45,19 @@ def export(export_folder, graphview, scheduler): ...@@ -45,15 +45,19 @@ def export(export_folder, graphview, scheduler):
list_configs = [] list_configs = []
list_forward_nodes = scheduler.get_static_scheduling() list_forward_nodes = scheduler.get_static_scheduling()
list_op = {}
for node in list_forward_nodes: for node in graphview.get_nodes():
if node.type() in supported_operators(): if node.type() in supported_operators():
op = EXPORT_CPP_REGISTRY[node.type()](node) op = EXPORT_CPP_REGISTRY[node.type()](node)
list_op[node.name()] = op
else: else:
continue continue
list_configs = op.export(dnn_folder, list_configs) list_configs = op.export(dnn_folder, list_configs)
list_actions = op.forward(list_actions)
for node in list_forward_nodes:
list_actions = list_op[node.name()].forward(list_actions)
# Memory management # Memory management
......
...@@ -241,9 +241,8 @@ class ReLUCPP(ExportNode): ...@@ -241,9 +241,8 @@ class ReLUCPP(ExportNode):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.input_dims = self.input.get_operator().output(0).dims()
self.nb_data = 1 self.nb_data = 1
for i in self.input_dims: for i in self.inputs_dims[0]:
self.nb_data *= i self.nb_data *= i
def export(self, export_folder:str, list_configs:list): def export(self, export_folder:str, list_configs:list):
...@@ -270,7 +269,7 @@ class ReLUCPP(ExportNode): ...@@ -270,7 +269,7 @@ class ReLUCPP(ExportNode):
list_actions.append(generate_action( list_actions.append(generate_action(
KERNELS_FORWARD.ACTIVATION, KERNELS_FORWARD.ACTIVATION,
name=self.name, name=self.name,
input_name=self.input[0].name(), input_name=self.inputs[0].name(),
output_name=self.name output_name=self.name
)) ))
return list_actions return list_actions
...@@ -367,6 +366,14 @@ class FcCPP(ExportNode): ...@@ -367,6 +366,14 @@ class FcCPP(ExportNode):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
if len(self.inputs_dims[0]) == 2:
self.inputs_dims[0] = [self.inputs_dims[0][1], 1, 1]
elif len(self.inputs_dims[0]) == 4:
self.inputs_dims[0] = self.inputs_dims[0][1:]
if len(self.outputs_dims[0]) == 2:
self.outputs_dims[0] = [self.outputs_dims[0][1], 1, 1]
def export(self, export_folder:str, list_configs:list): def export(self, export_folder:str, list_configs:list):
copyfile(KERNELS.FC, f"{export_folder}/include/kernels/") copyfile(KERNELS.FC, f"{export_folder}/include/kernels/")
...@@ -382,6 +389,10 @@ class FcCPP(ExportNode): ...@@ -382,6 +389,10 @@ class FcCPP(ExportNode):
output_dims=self.outputs_dims[0], output_dims=self.outputs_dims[0],
activation="Linear", activation="Linear",
rescaling="NoScaling") rescaling="NoScaling")
print(self.name)
print(self.inputs_dims[0])
print(self.outputs_dims[0])
# TODO : replace this by producer node export ! # TODO : replace this by producer node export !
# for i in range(len(self.parameters)): # for i in range(len(self.parameters)):
...@@ -408,3 +419,23 @@ class FcCPP(ExportNode): ...@@ -408,3 +419,23 @@ class FcCPP(ExportNode):
biases_name=self.inputs[2].name() biases_name=self.inputs[2].name()
)) ))
return list_actions return list_actions
@export_cpp_register("Producer")
class ProducerCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
self.values = np.array(self.operator.output(0))
def export(self, export_folder:str, list_configs:list):
list_configs.append(f"parameters/{self.name}.h")
export_to_static(self.name,
self.values.reshape(-1),
f"{export_folder}/parameters/{self.name}.h")
return list_configs
def forward(self, list_actions:list):
return list_actions
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment