Skip to content
Snippets Groups Projects
Commit 034f98a5 authored by Vincent Templier's avatar Vincent Templier
Browse files

Update export and operator files with NodeExport

parent b2eb8a11
No related branches found
No related tags found
No related merge requests found
......@@ -45,15 +45,19 @@ def export(export_folder, graphview, scheduler):
list_configs = []
list_forward_nodes = scheduler.get_static_scheduling()
list_op = {}
for node in list_forward_nodes:
for node in graphview.get_nodes():
if node.type() in supported_operators():
op = EXPORT_CPP_REGISTRY[node.type()](node)
list_op[node.name()] = op
else:
continue
list_configs = op.export(dnn_folder, list_configs)
list_actions = op.forward(list_actions)
for node in list_forward_nodes:
list_actions = list_op[node.name()].forward(list_actions)
# Memory management
......
......@@ -241,9 +241,8 @@ class ReLUCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
self.input_dims = self.input.get_operator().output(0).dims()
self.nb_data = 1
for i in self.input_dims:
for i in self.inputs_dims[0]:
self.nb_data *= i
def export(self, export_folder:str, list_configs:list):
......@@ -270,7 +269,7 @@ class ReLUCPP(ExportNode):
list_actions.append(generate_action(
KERNELS_FORWARD.ACTIVATION,
name=self.name,
input_name=self.input[0].name(),
input_name=self.inputs[0].name(),
output_name=self.name
))
return list_actions
......@@ -367,6 +366,14 @@ class FcCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
if len(self.inputs_dims[0]) == 2:
self.inputs_dims[0] = [self.inputs_dims[0][1], 1, 1]
elif len(self.inputs_dims[0]) == 4:
self.inputs_dims[0] = self.inputs_dims[0][1:]
if len(self.outputs_dims[0]) == 2:
self.outputs_dims[0] = [self.outputs_dims[0][1], 1, 1]
def export(self, export_folder:str, list_configs:list):
copyfile(KERNELS.FC, f"{export_folder}/include/kernels/")
......@@ -382,6 +389,10 @@ class FcCPP(ExportNode):
output_dims=self.outputs_dims[0],
activation="Linear",
rescaling="NoScaling")
print(self.name)
print(self.inputs_dims[0])
print(self.outputs_dims[0])
# TODO : replace this by producer node export !
# for i in range(len(self.parameters)):
......@@ -408,3 +419,23 @@ class FcCPP(ExportNode):
biases_name=self.inputs[2].name()
))
return list_actions
@export_cpp_register("Producer")
class ProducerCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
self.values = np.array(self.operator.output(0))
def export(self, export_folder:str, list_configs:list):
list_configs.append(f"parameters/{self.name}.h")
export_to_static(self.name,
self.values.reshape(-1),
f"{export_folder}/parameters/{self.name}.h")
return list_configs
def forward(self, list_actions:list):
return list_actions
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment