From 3931178a8af7f993a067ee9d568aca13e8f38dd3 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Tue, 2 Jul 2024 08:34:23 +0000 Subject: [PATCH] Add PaddedMaxPooling support --- aidge_export_cpp/operators.py | 57 +++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index 752d203..f0b09ba 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -289,6 +289,63 @@ class SubCPP(ExportNode): )) return list_actions + +@operator_register("PaddedMaxPooling") +class MaxPoolCPP(ExportNode): + def __init__(self, node): + super().__init__(node) + for n in self.operator.get_micro_graph().get_nodes(): + if n.type() == "Pad": + self.padding = n.get_operator().get_attr("BeginEndBorders") + if n.type() == "MaxPooling": + self.kernel = n.get_operator().get_attr("KernelDims") + self.stride = n.get_operator().get_attr("StrideDims") + + if len(self.inputs_dims[0]) == 4: + # if dims == [batch, nb_channels, height, width] + # transform to [nb_channels, height, width] + self.inputs_dims[0] = self.inputs_dims[0][1:] + + if len(self.outputs_dims[0]) == 4: + # if dims == [batch, nb_outputs] + # transform to [nb_outputs, 1, 1] + self.outputs_dims[0] = self.outputs_dims[0][1:] + + def export(self, export_folder:Path, list_configs:list): + + copyfile(str(ROOT / "kernels" / "pooling.hpp"), + str(export_folder / "include" / "kernels")) + + list_configs.append("kernels/pooling.hpp") + list_configs.append(f"layers/{self.name}.h") + + generate_file( + str(export_folder / "layers" / f"{self.name}.h"), + str(ROOT / "templates" / "configuration" / "pooling_config.jinja"), + name=self.name, + input_dims=self.inputs_dims[0], + output_dims=self.outputs_dims[0], + kernel=self.kernel, + stride=self.stride, + padding=self.padding, + pool_type="Max", + activation="Linear") + + return list_configs + + def forward(self, list_actions:list): + + if not self.is_last: + list_actions.append(set_up_output(self.name, "float")) + + list_actions.append(generate_str( + str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"), + name=self.name, + input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(), + output_name=self.name + )) + return list_actions + @operator_register("MaxPooling") class MaxPoolCPP(ExportNode): def __init__(self, node): -- GitLab