From b957fede729a835fc0fbeff9e5e797580afe7a58 Mon Sep 17 00:00:00 2001 From: Gallas Gaye <gallasko@gmail.com> Date: Fri, 21 Feb 2025 10:43:48 +0100 Subject: [PATCH] chore: Refactor duplicate code in operators --- aidge_export_cpp/operators.py | 161 ++++++++++++++-------------------- 1 file changed, 64 insertions(+), 97 deletions(-) diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index 0790877..346928f 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -122,6 +122,27 @@ class MatMulCPP(ExportNodeCpp): str(ROOT / "kernels" / "matmul.hpp"), ] +def _setup_conv2D(conv): + """Common setup code for convolutions: Conv2D and PaddedConv2D.""" + + # If biases are not provided we set it as nullptr instead of None + if (len(conv.attributes["in_name"]) > 2 and conv.attributes["in_name"][2] is None): + conv.attributes["in_name"][2] = "nullptr" + + conv.attributes["activation"] = "Linear" + conv.attributes["rescaling"] = "NoScaling" + conv.config_template = str( + ROOT / "templates" / "configuration" / "convolution_config.jinja") + conv.forward_template = str( + ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja") + conv.include_list = [] + conv.kernels_to_copy = [ + str(ROOT / "kernels" / "convolution.hpp"), + str(ROOT / "kernels" / "macs.hpp"), + str(ROOT / "kernels" / "activation.hpp"), + str(ROOT / "kernels" / "rescaling.hpp") + ] + @ExportLibCpp.register("Conv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class ConvCPP(ExportNodeCpp): def __init__(self, node, mem_info): @@ -129,19 +150,8 @@ class ConvCPP(ExportNodeCpp): # No padding with Conv # Use PaddedConv to add padding attribute self.attributes["padding"] = [0, 0] - self.attributes["activation"] = "Linear" - self.attributes["rescaling"] = "NoScaling" - self.config_template = str( - ROOT / "templates" / "configuration" / "convolution_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "convolution.hpp"), - str(ROOT / "kernels" / "macs.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] + + _setup_conv2D(self) @ExportLibCpp.register_metaop("PaddedConv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class PaddedConvCPP(ExportNodeCpp): @@ -159,74 +169,60 @@ class PaddedConvCPP(ExportNodeCpp): ).attr.stride_dims self.attributes["dilation_dims"] = n.get_operator( ).attr.dilation_dims - self.attributes["activation"] = "Linear" - self.attributes["rescaling"] = "NoScaling" - self.config_template = str( - ROOT / "templates" / "configuration" / "convolution_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "convolution.hpp"), - str(ROOT / "kernels" / "macs.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] + + _setup_conv2D(self) + +def _setup_elemwise_op(elemwise, op): + """Common code (template and kernel setup) shared across all the different elementWise operator (Add, Sub,...).""" + + elemwise.attributes["elemwise_op"] = op + elemwise.attributes["activation"] = "Linear" + elemwise.attributes["rescaling"] = "NoScaling" + elemwise.config_template = str( + ROOT / "templates" / "configuration" / "elemwise_config.jinja") + elemwise.forward_template = str( + ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") + elemwise.include_list = [] + elemwise.kernels_to_copy = [ + str(ROOT / "kernels" / "elemwise.hpp"), + str(ROOT / "kernels" / "activation.hpp"), + str(ROOT / "kernels" / "rescaling.hpp") + ] @ExportLibCpp.register("Add", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class AddCPP(ExportNodeCpp): def __init__(self, node, mem_info): super().__init__(node, mem_info) - self.attributes["elemwise_op"] = "Add" - self.attributes["activation"] = "Linear" - self.attributes["rescaling"] = "NoScaling" - self.config_template = str( - ROOT / "templates" / "configuration" / "elemwise_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "elemwise.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] + + _setup_elemwise_op(self, "Add") @ExportLibCpp.register("Sub", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class SubCPP(ExportNodeCpp): def __init__(self, node, mem_info): super().__init__(node, mem_info) - self.attributes["elemwise_op"] = "Sub" - self.attributes["activation"] = "Linear" - self.attributes["rescaling"] = "NoScaling" - self.config_template = str( - ROOT / "templates" / "configuration" / "elemwise_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "elemwise.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] + _setup_elemwise_op(self, "Sub") @ExportLibCpp.register("Mul", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class MulCPP(ExportNodeCpp): def __init__(self, node, mem_info): super().__init__(node, mem_info) - self.attributes["elemwise_op"] = "Mul" - self.attributes["activation"] = "Linear" - self.attributes["rescaling"] = "NoScaling" - self.config_template = str( - ROOT / "templates" / "configuration" / "elemwise_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "elemwise.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] + + _setup_elemwise_op(self, "Mul") + +def _setup_pooling(pooling): + """Common code (template and kernel setup) shared across all the different pooling operator.""" + + pooling.config_template = str( + ROOT / "templates" / "configuration" / "pooling_config.jinja") + pooling.forward_template = str( + ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja") + pooling.include_list = [] + pooling.kernels_to_copy = [ + str(ROOT / "kernels" / "pooling.hpp"), + str(ROOT / "kernels" / "activation.hpp"), + str(ROOT / "kernels" / "rescaling.hpp") + ] @ExportLibCpp.register("MaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class MaxPoolCPP(ExportNodeCpp): @@ -239,17 +235,7 @@ class MaxPoolCPP(ExportNodeCpp): self.attributes["pool_type"] = "Max" self.attributes["activation"] = "Linear" - self.config_template = str( - ROOT / "templates" / "configuration" / "pooling_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "pooling.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] - + _setup_pooling(self) @ExportLibCpp.register_metaop("PaddedMaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class PaddedMaxPoolCPP(ExportNodeCpp): @@ -267,17 +253,7 @@ class PaddedMaxPoolCPP(ExportNodeCpp): self.attributes["pool_type"] = "Max" self.attributes["activation"] = "Linear" - self.config_template = str( - ROOT / "templates" / "configuration" / "pooling_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "pooling.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] - + _setup_pooling(self) @ExportLibCpp.register("GlobalAveragePooling", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class GlobalAveragePoolCPP(ExportNodeCpp): @@ -295,16 +271,7 @@ class GlobalAveragePoolCPP(ExportNodeCpp): self.attributes["pool_type"] = "Average" self.attributes["activation"] = "Linear" - self.config_template = str( - ROOT / "templates" / "configuration" / "pooling_config.jinja") - self.forward_template = str( - ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja") - self.include_list = [] - self.kernels_to_copy = [ - str(ROOT / "kernels" / "pooling.hpp"), - str(ROOT / "kernels" / "activation.hpp"), - str(ROOT / "kernels" / "rescaling.hpp") - ] + _setup_pooling(self) @ExportLibCpp.register("FC", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class FcCPP(ExportNodeCpp): -- GitLab