From b957fede729a835fc0fbeff9e5e797580afe7a58 Mon Sep 17 00:00:00 2001
From: Gallas Gaye <gallasko@gmail.com>
Date: Fri, 21 Feb 2025 10:43:48 +0100
Subject: [PATCH] chore: Refactor duplicate code in operators

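Factor the template and kernel setup that was duplicated across the
export classes into three module-level helpers:

  * _setup_conv2D(conv): Conv2D and PaddedConv2D
  * _setup_elemwise_op(elemwise, op): Add, Sub and Mul
  * _setup_pooling(pooling): MaxPooling2D, PaddedMaxPooling2D and
    GlobalAveragePooling

Each operator's __init__ now only sets its operator-specific attributes
and delegates the shared setup to the matching helper. As a sketch, a
hypothetical new element-wise operator ("Div", not part of this patch)
would reduce to:

    @ExportLibCpp.register("Div", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
    class DivCPP(ExportNodeCpp):
        def __init__(self, node, mem_info):
            super().__init__(node, mem_info)
            # Shared element-wise setup; assumes the elemwise kernel
            # supports a "Div" operation.
            _setup_elemwise_op(self, "Div")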
---
 aidge_export_cpp/operators.py | 161 ++++++++++++++--------------------
 1 file changed, 64 insertions(+), 97 deletions(-)

diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py
index 0790877..346928f 100644
--- a/aidge_export_cpp/operators.py
+++ b/aidge_export_cpp/operators.py
@@ -122,6 +122,27 @@ class MatMulCPP(ExportNodeCpp):
             str(ROOT / "kernels" / "matmul.hpp"),
         ]
 
+def _setup_conv2D(conv):
+    """Common setup code for convolutions: Conv2D and PaddedConv2D."""
+
+    # If no bias is provided, set it to "nullptr" instead of None
+    if len(conv.attributes["in_name"]) > 2 and conv.attributes["in_name"][2] is None:
+        conv.attributes["in_name"][2] = "nullptr"
+
+    conv.attributes["activation"] = "Linear"
+    conv.attributes["rescaling"] = "NoScaling"
+    conv.config_template = str(
+        ROOT / "templates" / "configuration" / "convolution_config.jinja")
+    conv.forward_template = str(
+        ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
+    conv.include_list = []
+    conv.kernels_to_copy = [
+        str(ROOT / "kernels" / "convolution.hpp"),
+        str(ROOT / "kernels" / "macs.hpp"),
+        str(ROOT / "kernels" / "activation.hpp"),
+        str(ROOT / "kernels" / "rescaling.hpp")
+    ]
+
 @ExportLibCpp.register("Conv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class ConvCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
@@ -129,19 +150,8 @@ class ConvCPP(ExportNodeCpp):
         # No padding with Conv
         # Use PaddedConv to add padding attribute
         self.attributes["padding"] = [0, 0]
-        self.attributes["activation"] = "Linear"
-        self.attributes["rescaling"] = "NoScaling"
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "convolution_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "convolution.hpp"),
-            str(ROOT / "kernels" / "macs.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
+
+        _setup_conv2D(self)
 
 @ExportLibCpp.register_metaop("PaddedConv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class PaddedConvCPP(ExportNodeCpp):
@@ -159,74 +169,60 @@ class PaddedConvCPP(ExportNodeCpp):
                 ).attr.stride_dims
                 self.attributes["dilation_dims"] = n.get_operator(
                 ).attr.dilation_dims
-        self.attributes["activation"] = "Linear"
-        self.attributes["rescaling"] = "NoScaling"
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "convolution_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "convolution.hpp"),
-            str(ROOT / "kernels" / "macs.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
+
+        _setup_conv2D(self)
+
+def _setup_elemwise_op(elemwise, op):
+    """Common code (template and kernel setup) shared across all the different elementWise operator (Add, Sub,...)."""
+
+    elemwise.attributes["elemwise_op"] = op
+    elemwise.attributes["activation"] = "Linear"
+    elemwise.attributes["rescaling"] = "NoScaling"
+    elemwise.config_template = str(
+        ROOT / "templates" / "configuration" / "elemwise_config.jinja")
+    elemwise.forward_template = str(
+        ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
+    elemwise.include_list = []
+    elemwise.kernels_to_copy = [
+        str(ROOT / "kernels" / "elemwise.hpp"),
+        str(ROOT / "kernels" / "activation.hpp"),
+        str(ROOT / "kernels" / "rescaling.hpp")
+    ]
 
 @ExportLibCpp.register("Add", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class AddCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
-        self.attributes["elemwise_op"] = "Add"
-        self.attributes["activation"] = "Linear"
-        self.attributes["rescaling"] = "NoScaling"
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "elemwise_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "elemwise.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
+
+        _setup_elemwise_op(self, "Add")
 
 @ExportLibCpp.register("Sub", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class SubCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
-        self.attributes["elemwise_op"] = "Sub"
-        self.attributes["activation"] = "Linear"
-        self.attributes["rescaling"] = "NoScaling"
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "elemwise_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "elemwise.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
 
+        _setup_elemwise_op(self, "Sub")
 
 @ExportLibCpp.register("Mul", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class MulCPP(ExportNodeCpp):
     def __init__(self, node, mem_info):
         super().__init__(node, mem_info)
-        self.attributes["elemwise_op"] = "Mul"
-        self.attributes["activation"] = "Linear"
-        self.attributes["rescaling"] = "NoScaling"
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "elemwise_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "elemwise.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
+
+        _setup_elemwise_op(self, "Mul")
+
+def _setup_pooling(pooling):
+    """Common code (template and kernel setup) shared across all the different pooling operator."""
+
+    pooling.config_template = str(
+        ROOT / "templates" / "configuration" / "pooling_config.jinja")
+    pooling.forward_template = str(
+        ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
+    pooling.include_list = []
+    pooling.kernels_to_copy = [
+        str(ROOT / "kernels" / "pooling.hpp"),
+        str(ROOT / "kernels" / "activation.hpp"),
+        str(ROOT / "kernels" / "rescaling.hpp")
+    ]
 
 @ExportLibCpp.register("MaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class MaxPoolCPP(ExportNodeCpp):
@@ -239,17 +235,7 @@ class MaxPoolCPP(ExportNodeCpp):
         self.attributes["pool_type"] = "Max"
         self.attributes["activation"] = "Linear"
 
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "pooling_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "pooling.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
-
+        _setup_pooling(self)
 
 @ExportLibCpp.register_metaop("PaddedMaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class PaddedMaxPoolCPP(ExportNodeCpp):
@@ -267,17 +253,7 @@ class PaddedMaxPoolCPP(ExportNodeCpp):
         self.attributes["pool_type"] = "Max"
         self.attributes["activation"] = "Linear"
 
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "pooling_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "pooling.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
-
+        _setup_pooling(self)
 
 @ExportLibCpp.register("GlobalAveragePooling", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class GlobalAveragePoolCPP(ExportNodeCpp):
@@ -295,16 +271,7 @@ class GlobalAveragePoolCPP(ExportNodeCpp):
         self.attributes["pool_type"] = "Average"
         self.attributes["activation"] = "Linear"
 
-        self.config_template = str(
-            ROOT / "templates" / "configuration" / "pooling_config.jinja")
-        self.forward_template = str(
-            ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
-        self.include_list = []
-        self.kernels_to_copy = [
-            str(ROOT / "kernels" / "pooling.hpp"),
-            str(ROOT / "kernels" / "activation.hpp"),
-            str(ROOT / "kernels" / "rescaling.hpp")
-        ]
+        _setup_pooling(self)
 
 @ExportLibCpp.register("FC", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
 class FcCPP(ExportNodeCpp):
-- 
GitLab