From c92d411dca4a234c7bfc608da59e7471091670f7 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Wed, 4 Sep 2024 12:10:08 +0000
Subject: [PATCH] Add kernels copy for N2D2 kernels.

---
 aidge_export_arm_cortexm/operators.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/aidge_export_arm_cortexm/operators.py b/aidge_export_arm_cortexm/operators.py
index bbc111d..24611dd 100644
--- a/aidge_export_arm_cortexm/operators.py
+++ b/aidge_export_arm_cortexm/operators.py
@@ -188,7 +188,7 @@ class ReLU_ARMCortexM(ExportNodeCpp):
         self.forward_template = str(ROOT / "_Aidge_Arm" / "templates" / "forward_call" / "relu.jinja")
         self.include_list = []
         self.kernels_to_copy = [
-            str(ROOT / "_Aidge_Arm" / "kernels" / "Relu" / "aidge_relu_float32.c"),
+            str(ROOT / "_Aidge_Arm" / "kernels" / "Relu" / "aidge_relu_float32.h"),
         ]
     @classmethod
     def exportable(cls, node):
@@ -207,7 +207,9 @@ class Conv_ARMCortexM(ExportNodeCpp):
         self.config_template = str(ROOT / "_Aidge_Arm" / "templates" / "configuration" / "conv_config.jinja")
         self.forward_template = str(ROOT / "_Aidge_Arm" / "templates" / "forward_call" / "conv_kernel.jinja")
         self.include_list = []
-        self.kernels_to_copy = []
+        self.kernels_to_copy = [
+            str(ROOT / "_Aidge_Arm" / "kernels" / "Convolution" / "Conv.hpp")
+        ]
     @classmethod
     def exportable(cls, node):
         return True # TODO add check i/o NCHW
@@ -233,7 +235,9 @@ class PaddedConv_ARMCortexM(ExportNodeCpp):
         self.config_template = str(ROOT / "_Aidge_Arm" / "templates" / "configuration" / "conv_config.jinja")
         self.forward_template = str(ROOT / "_Aidge_Arm" / "templates" / "forward_call" / "conv_kernel.jinja")
         self.include_list = []
-        self.kernels_to_copy = []
+        self.kernels_to_copy = [
+            str(ROOT / "_Aidge_Arm" / "kernels" / "Convolution" / "Conv.hpp")
+        ]
     @classmethod
     def exportable(cls, node):
         return True # TODO add check i/o NCHW
@@ -253,7 +257,9 @@ class Pooling_ARMCortexM(ExportNodeCpp):
         self.config_template = str(ROOT / "_Aidge_Arm" / "templates" / "configuration" / "pool_config.jinja")
         self.forward_template = str(ROOT / "_Aidge_Arm" / "templates" / "forward_call" / "pool_kernel.jinja")
         self.include_list = []
-        self.kernels_to_copy = []
+        self.kernels_to_copy = [
+            str(ROOT / "_Aidge_Arm" / "kernels" / "Pooling" / "Pooling.hpp")
+        ]
         self.kernel = node.get_operator().attr.kernel_dims
         self.stride = node.get_operator().attr.stride_dims
     @classmethod
@@ -273,7 +279,9 @@ class FC_ARMCortexM(ExportNodeCpp):
         self.config_template = str(ROOT / "_Aidge_Arm" / "templates" / "configuration" / "fc_config.jinja")
         self.forward_template = str(ROOT / "_Aidge_Arm" / "templates" / "forward_call" / "fc_kernel.jinja")
         self.include_list = []
-        self.kernels_to_copy = []
+        self.kernels_to_copy = [
+            str(ROOT / "_Aidge_Arm" / "kernels" / "FullyConnected" / "Fc.hpp")
+        ]
     @classmethod
     def exportable(cls, node):
         return True # TODO add check i/o NCHW
-- 
GitLab