Skip to content
Snippets Groups Projects
Commit 3931178a authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add PaddedMaxPooling support

parent 39d2c8d6
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
......@@ -289,6 +289,63 @@ class SubCPP(ExportNode):
))
return list_actions
@operator_register("PaddedMaxPooling")
class MaxPoolCPP(ExportNode):
def __init__(self, node):
super().__init__(node)
for n in self.operator.get_micro_graph().get_nodes():
if n.type() == "Pad":
self.padding = n.get_operator().get_attr("BeginEndBorders")
if n.type() == "MaxPooling":
self.kernel = n.get_operator().get_attr("KernelDims")
self.stride = n.get_operator().get_attr("StrideDims")
if len(self.inputs_dims[0]) == 4:
# if dims == [batch, nb_channels, height, width]
# transform to [nb_channels, height, width]
self.inputs_dims[0] = self.inputs_dims[0][1:]
if len(self.outputs_dims[0]) == 4:
# if dims == [batch, nb_outputs]
# transform to [nb_outputs, 1, 1]
self.outputs_dims[0] = self.outputs_dims[0][1:]
def export(self, export_folder:Path, list_configs:list):
copyfile(str(ROOT / "kernels" / "pooling.hpp"),
str(export_folder / "include" / "kernels"))
list_configs.append("kernels/pooling.hpp")
list_configs.append(f"layers/{self.name}.h")
generate_file(
str(export_folder / "layers" / f"{self.name}.h"),
str(ROOT / "templates" / "configuration" / "pooling_config.jinja"),
name=self.name,
input_dims=self.inputs_dims[0],
output_dims=self.outputs_dims[0],
kernel=self.kernel,
stride=self.stride,
padding=self.padding,
pool_type="Max",
activation="Linear")
return list_configs
def forward(self, list_actions:list):
if not self.is_last:
list_actions.append(set_up_output(self.name, "float"))
list_actions.append(generate_str(
str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"),
name=self.name,
input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
output_name=self.name
))
return list_actions
@operator_register("MaxPooling")
class MaxPoolCPP(ExportNode):
def __init__(self, node):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment