Skip to content
Snippets Groups Projects
Commit b957fede authored by Gallas Gaye's avatar Gallas Gaye Committed by Gallas Gaye
Browse files

chore: Refactor duplicate code in operators

parent 534f5b33
No related branches found
No related tags found
2 merge requests!39Update 0.2.1 -> 0.3.0,!31Add missing operators for basic onnx model exporting
...@@ -122,6 +122,27 @@ class MatMulCPP(ExportNodeCpp): ...@@ -122,6 +122,27 @@ class MatMulCPP(ExportNodeCpp):
str(ROOT / "kernels" / "matmul.hpp"), str(ROOT / "kernels" / "matmul.hpp"),
] ]
def _setup_conv2D(conv):
"""Common setup code for convolutions: Conv2D and PaddedConv2D."""
# If biases are not provided we set it as nullptr instead of None
if (len(conv.attributes["in_name"]) > 2 and conv.attributes["in_name"][2] is None):
conv.attributes["in_name"][2] = "nullptr"
conv.attributes["activation"] = "Linear"
conv.attributes["rescaling"] = "NoScaling"
conv.config_template = str(
ROOT / "templates" / "configuration" / "convolution_config.jinja")
conv.forward_template = str(
ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
conv.include_list = []
conv.kernels_to_copy = [
str(ROOT / "kernels" / "convolution.hpp"),
str(ROOT / "kernels" / "macs.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register("Conv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register("Conv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class ConvCPP(ExportNodeCpp): class ConvCPP(ExportNodeCpp):
def __init__(self, node, mem_info): def __init__(self, node, mem_info):
...@@ -129,19 +150,8 @@ class ConvCPP(ExportNodeCpp): ...@@ -129,19 +150,8 @@ class ConvCPP(ExportNodeCpp):
# No padding with Conv # No padding with Conv
# Use PaddedConv to add padding attribute # Use PaddedConv to add padding attribute
self.attributes["padding"] = [0, 0] self.attributes["padding"] = [0, 0]
self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" _setup_conv2D(self)
self.config_template = str(
ROOT / "templates" / "configuration" / "convolution_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "kernels" / "convolution.hpp"),
str(ROOT / "kernels" / "macs.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register_metaop("PaddedConv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register_metaop("PaddedConv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class PaddedConvCPP(ExportNodeCpp): class PaddedConvCPP(ExportNodeCpp):
...@@ -159,74 +169,60 @@ class PaddedConvCPP(ExportNodeCpp): ...@@ -159,74 +169,60 @@ class PaddedConvCPP(ExportNodeCpp):
).attr.stride_dims ).attr.stride_dims
self.attributes["dilation_dims"] = n.get_operator( self.attributes["dilation_dims"] = n.get_operator(
).attr.dilation_dims ).attr.dilation_dims
self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" _setup_conv2D(self)
self.config_template = str(
ROOT / "templates" / "configuration" / "convolution_config.jinja") def _setup_elemwise_op(elemwise, op):
self.forward_template = str( """Common code (template and kernel setup) shared across all the different elementWise operator (Add, Sub,...)."""
ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
self.include_list = [] elemwise.attributes["elemwise_op"] = op
self.kernels_to_copy = [ elemwise.attributes["activation"] = "Linear"
str(ROOT / "kernels" / "convolution.hpp"), elemwise.attributes["rescaling"] = "NoScaling"
str(ROOT / "kernels" / "macs.hpp"), elemwise.config_template = str(
str(ROOT / "kernels" / "activation.hpp"), ROOT / "templates" / "configuration" / "elemwise_config.jinja")
str(ROOT / "kernels" / "rescaling.hpp") elemwise.forward_template = str(
] ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
elemwise.include_list = []
elemwise.kernels_to_copy = [
str(ROOT / "kernels" / "elemwise.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register("Add", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register("Add", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class AddCPP(ExportNodeCpp): class AddCPP(ExportNodeCpp):
def __init__(self, node, mem_info): def __init__(self, node, mem_info):
super().__init__(node, mem_info) super().__init__(node, mem_info)
self.attributes["elemwise_op"] = "Add"
self.attributes["activation"] = "Linear" _setup_elemwise_op(self, "Add")
self.attributes["rescaling"] = "NoScaling"
self.config_template = str(
ROOT / "templates" / "configuration" / "elemwise_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "kernels" / "elemwise.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register("Sub", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register("Sub", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class SubCPP(ExportNodeCpp): class SubCPP(ExportNodeCpp):
def __init__(self, node, mem_info): def __init__(self, node, mem_info):
super().__init__(node, mem_info) super().__init__(node, mem_info)
self.attributes["elemwise_op"] = "Sub"
self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling"
self.config_template = str(
ROOT / "templates" / "configuration" / "elemwise_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "kernels" / "elemwise.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
_setup_elemwise_op(self, "Sub")
@ExportLibCpp.register("Mul", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register("Mul", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class MulCPP(ExportNodeCpp): class MulCPP(ExportNodeCpp):
def __init__(self, node, mem_info): def __init__(self, node, mem_info):
super().__init__(node, mem_info) super().__init__(node, mem_info)
self.attributes["elemwise_op"] = "Mul"
self.attributes["activation"] = "Linear" _setup_elemwise_op(self, "Mul")
self.attributes["rescaling"] = "NoScaling"
self.config_template = str( def _setup_pooling(pooling):
ROOT / "templates" / "configuration" / "elemwise_config.jinja") """Common code (template and kernel setup) shared across all the different pooling operator."""
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") pooling.config_template = str(
self.include_list = [] ROOT / "templates" / "configuration" / "pooling_config.jinja")
self.kernels_to_copy = [ pooling.forward_template = str(
str(ROOT / "kernels" / "elemwise.hpp"), ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
str(ROOT / "kernels" / "activation.hpp"), pooling.include_list = []
str(ROOT / "kernels" / "rescaling.hpp") pooling.kernels_to_copy = [
] str(ROOT / "kernels" / "pooling.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register("MaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register("MaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class MaxPoolCPP(ExportNodeCpp): class MaxPoolCPP(ExportNodeCpp):
...@@ -239,17 +235,7 @@ class MaxPoolCPP(ExportNodeCpp): ...@@ -239,17 +235,7 @@ class MaxPoolCPP(ExportNodeCpp):
self.attributes["pool_type"] = "Max" self.attributes["pool_type"] = "Max"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.config_template = str( _setup_pooling(self)
ROOT / "templates" / "configuration" / "pooling_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "kernels" / "pooling.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register_metaop("PaddedMaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register_metaop("PaddedMaxPooling2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class PaddedMaxPoolCPP(ExportNodeCpp): class PaddedMaxPoolCPP(ExportNodeCpp):
...@@ -267,17 +253,7 @@ class PaddedMaxPoolCPP(ExportNodeCpp): ...@@ -267,17 +253,7 @@ class PaddedMaxPoolCPP(ExportNodeCpp):
self.attributes["pool_type"] = "Max" self.attributes["pool_type"] = "Max"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.config_template = str( _setup_pooling(self)
ROOT / "templates" / "configuration" / "pooling_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "kernels" / "pooling.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register("GlobalAveragePooling", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register("GlobalAveragePooling", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class GlobalAveragePoolCPP(ExportNodeCpp): class GlobalAveragePoolCPP(ExportNodeCpp):
...@@ -295,16 +271,7 @@ class GlobalAveragePoolCPP(ExportNodeCpp): ...@@ -295,16 +271,7 @@ class GlobalAveragePoolCPP(ExportNodeCpp):
self.attributes["pool_type"] = "Average" self.attributes["pool_type"] = "Average"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.config_template = str( _setup_pooling(self)
ROOT / "templates" / "configuration" / "pooling_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
self.include_list = []
self.kernels_to_copy = [
str(ROOT / "kernels" / "pooling.hpp"),
str(ROOT / "kernels" / "activation.hpp"),
str(ROOT / "kernels" / "rescaling.hpp")
]
@ExportLibCpp.register("FC", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) @ExportLibCpp.register("FC", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32)))
class FcCPP(ExportNodeCpp): class FcCPP(ExportNodeCpp):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment