Skip to content
Snippets Groups Projects
Commit 9b5a4e86 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Fix] Resnet export.

parent 86c33bf4
Branches main
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
...@@ -82,7 +82,7 @@ def export(export_folder_name, graphview, scheduler): ...@@ -82,7 +82,7 @@ def export(export_folder_name, graphview, scheduler):
node_input, _ = node_input_tuple node_input, _ = node_input_tuple
if node_input is None: if node_input is None:
export_type = aidge2c(node.get_operator().get_output(0).dtype()) export_type = aidge2c(node.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, f"{node.name()}_{idx}")) list_inputs_name.append((export_type, f"{node.name()}_input_{idx}"))
elif node_input not in graphview.get_nodes(): elif node_input not in graphview.get_nodes():
export_type = aidge2c(node_input.get_operator().get_output(0).dtype()) export_type = aidge2c(node_input.get_operator().get_output(0).dtype())
list_inputs_name.append((export_type, node_input.name())) list_inputs_name.append((export_type, node_input.name()))
......
...@@ -404,6 +404,67 @@ class MaxPoolCPP(ExportNode): ...@@ -404,6 +404,67 @@ class MaxPoolCPP(ExportNode):
)) ))
return list_actions return list_actions
@operator_register("GlobalAveragePooling")
class GlobalAveragePoolCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
self.stride = [1, 1]
# No padding with MaxPooling
# Use PaddedMaxPooling to add padding attribute
self.padding = [0, 0]
if len(self.inputs_dims[0]) == 4:
# if dims == [batch, nb_channels, height, width]
# transform to [nb_channels, height, width]
self.inputs_dims[0] = self.inputs_dims[0][1:]
self.kernel = self.inputs_dims[0][1:]
else:
raise RuntimeError("Input dims != 4 not supported.")
if len(self.outputs_dims[0]) == 4:
# if dims == [batch, nb_outputs]
# transform to [nb_outputs, 1, 1]
self.outputs_dims[0] = self.outputs_dims[0][1:]
elif len(self.outputs_dims[0]) == 2:
self.outputs_dims[0] = [self.outputs_dims[0][1], 1, 1]
def export(self, export_folder:Path, list_configs:list):
copyfile(str(ROOT / "kernels" / "pooling.hpp"),
str(export_folder / "include" / "kernels"))
list_configs.append("kernels/pooling.hpp")
list_configs.append(f"layers/{self.name}.h")
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
str(ROOT / "templates" / "configuration" / "pooling_config.jinja"),
name=self.name,
input_dims=self.inputs_dims[0],
output_dims=self.outputs_dims[0],
kernel=self.kernel,
stride=self.stride,
padding=self.padding,
pool_type="Average",
activation="Linear")
return list_configs
def forward(self, list_actions:list):
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"),
name=self.name,
input_name=self.inputs[0].name() if self.inputs[0] else self.name + "_input",
output_name=self.name
))
return list_actions
@operator_register("FC") @operator_register("FC")
class FcCPP(ExportNode): class FcCPP(ExportNode):
def __init__(self, node): def __init__(self, node):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment