From 3931178a8af7f993a067ee9d568aca13e8f38dd3 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 2 Jul 2024 08:34:23 +0000
Subject: [PATCH] Add PaddedMaxPooling support

---
 aidge_export_cpp/operators.py | 57 +++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 752d203..f0b09ba 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -289,6 +289,63 @@ class SubCPP(ExportNode):
         ))
         return list_actions
 
+
+@operator_register("PaddedMaxPooling")
+class MaxPoolCPP(ExportNode):
+    def __init__(self, node):
+        super().__init__(node)
+        for n in self.operator.get_micro_graph().get_nodes():
+            if n.type() == "Pad":
+                self.padding = n.get_operator().get_attr("BeginEndBorders")
+            if n.type() == "MaxPooling":
+                self.kernel = n.get_operator().get_attr("KernelDims")
+                self.stride = n.get_operator().get_attr("StrideDims")
+
+        if len(self.inputs_dims[0]) == 4:
+            # if dims == [batch, nb_channels, height, width]
+            # transform to [nb_channels, height, width]
+            self.inputs_dims[0] = self.inputs_dims[0][1:]
+
+        if len(self.outputs_dims[0]) == 4:
+            # if dims == [batch, nb_outputs]
+            # transform to [nb_outputs, 1, 1]
+            self.outputs_dims[0] = self.outputs_dims[0][1:]
+
+    def export(self, export_folder:Path, list_configs:list):
+
+        copyfile(str(ROOT / "kernels" / "pooling.hpp"),
+                 str(export_folder / "include" / "kernels"))
+
+        list_configs.append("kernels/pooling.hpp")
+        list_configs.append(f"layers/{self.name}.h")
+
+        generate_file(
+            str(export_folder / "layers" / f"{self.name}.h"),
+            str(ROOT / "templates" / "configuration" / "pooling_config.jinja"),
+            name=self.name,
+            input_dims=self.inputs_dims[0],
+            output_dims=self.outputs_dims[0],
+            kernel=self.kernel,
+            stride=self.stride,
+            padding=self.padding,
+            pool_type="Max",
+            activation="Linear")
+
+        return list_configs
+
+    def forward(self, list_actions:list):
+
+        if not self.is_last:
+            list_actions.append(set_up_output(self.name, "float"))
+
+        list_actions.append(generate_str(
+            str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja"),
+            name=self.name,
+            input_name=f"{self.name}_0" if self.inputs[0] is None else self.inputs[0].name(),
+            output_name=self.name
+        ))
+        return list_actions
+
 @operator_register("MaxPooling")
 class MaxPoolCPP(ExportNode):
     def __init__(self, node):
-- 
GitLab