diff --git a/aidge_export_cpp/operators.py b/aidge_export_cpp/operators.py index 59ce94ad789b8f10c25cfe3ae03a4f865fcc2c78..0790877f8d01ad6ff2a71de52971202151cc377e 100644 --- a/aidge_export_cpp/operators.py +++ b/aidge_export_cpp/operators.py @@ -107,6 +107,21 @@ class ReshapeCPP(ExportNodeCpp): str(ROOT / "kernels" / "reshape.hpp"), ] +@ExportLibCpp.register("MatMul", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) +class MatMulCPP(ExportNodeCpp): + def __init__(self, node, mem_info): + super().__init__(node, mem_info) + self.attributes["activation"] = "Linear" + self.attributes["rescaling"] = "NoScaling" + self.config_template = str( + ROOT / "templates" / "configuration" / "matmul_config.jinja") + self.forward_template = str( + ROOT / "templates" / "kernel_forward" / "matmul_forward.jinja") + self.include_list = [] + self.kernels_to_copy = [ + str(ROOT / "kernels" / "matmul.hpp"), + ] + @ExportLibCpp.register("Conv2D", aidge_core.ImplSpec(aidge_core.IOSpec(aidge_core.dtype.float32))) class ConvCPP(ExportNodeCpp): def __init__(self, node, mem_info): diff --git a/aidge_export_cpp/templates/configuration/matmul_config.jinja b/aidge_export_cpp/templates/configuration/matmul_config.jinja index fece988ac13b0136a8506abb39998114923817d6..38316f20947fa726085bf3577ead510e6c5096f3 100644 --- a/aidge_export_cpp/templates/configuration/matmul_config.jinja +++ b/aidge_export_cpp/templates/configuration/matmul_config.jinja @@ -2,10 +2,13 @@ #ifndef {{ name|upper }}_LAYER_H #define {{ name|upper }}_LAYER_H +{% include "./_def_io.jinja" %} +{% include "./_meminfo.jinja" %} + {# For layer configuration -#} -#define {{ name|upper }}_M {{ inputs_dims[0][0] }} -#define {{ name|upper }}_K {{ inputs_dims[0][1] }} -#define {{ name|upper }}_N {{ inputs_dims[1][1] }} +#define {{ name|upper }}_M {{ in_dims[0][0] }} +#define {{ name|upper }}_K {{ in_dims[0][1] }} +#define {{ name|upper }}_N {{ in_dims[1][1] }} #define {{ name|upper }}_ACTIVATION {{ activation }} static const {{ rescaling }} {{ name|upper }}_RESCALING = {}; diff --git a/aidge_export_cpp/templates/kernel_forward/matmul_forward.jinja b/aidge_export_cpp/templates/kernel_forward/matmul_forward.jinja index ce80ffd2abc90ad611d3008c57aae36383691452..64b3df301794e1cb3d56170646a6b9524f18a6ab 100644 --- a/aidge_export_cpp/templates/kernel_forward/matmul_forward.jinja +++ b/aidge_export_cpp/templates/kernel_forward/matmul_forward.jinja @@ -1,5 +1,9 @@ +{% filter indent(width=4, first=False) %} +{% include "./_mem_offset.jinja" %} matmul_forward<{{name|upper}}_M, {{name|upper}}_K, {{name|upper}}_N, {{name|upper}}_ACTIVATION> - ({{inputs1_name}}, {{inputs2_name}}, {{outputs_name}}, {{name|upper}}_RESCALING); \ No newline at end of file + ({{in_name[0]}}, {{in_name[1]}}, {{out_name[0]}}, {{name|upper}}_RESCALING); +{% include "./_save_outputs.jinja" %} +{% endfilter %}