From 6f0ee9736132da39d1604b507300f2ba69324dbd Mon Sep 17 00:00:00 2001
From: Axel Farrugia <axel.farrugia@cea.fr>
Date: Fri, 8 Nov 2024 11:44:50 +0100
Subject: [PATCH] feat(export): Add a way to give a custom forward template &
 add test argument in scheduler_export() function

---
 aidge_core/export_utils/node_export.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index 5777814a0..9f54945de 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -362,11 +362,11 @@ class ExportNodeCpp(ExportNode):
     # List of includes to add example "include/toto.hpp"
     include_list: list = None
     # A list of path of kernels to copy in the export
-    # kernels are copied in str(export_folder / "include" / "kernels")
+    # kernels are copied in the path at the same index within the kernels_path list.
     # They are automatically added to the include list.
     kernels_to_copy: list = None
     # Path where all the kernels are stored in the export (prefixed by export_root)
-    kernels_path: str = "include/kernels"
+    kernels_path: list = None
     # Path of config folders
     config_path: str = "include/layers"
     # Config_folder_extension
@@ -392,16 +392,18 @@ class ExportNodeCpp(ExportNode):
             raise ValueError("include_list have not been defined")
         if self.kernels_to_copy is None:
             raise ValueError("kernels_to_copy have not been defined")
+        if self.kernels_path is None:
+            raise ValueError("kernels_path have not been defined")
 
         kernel_include_list = []
-        for kernel in self.kernels_to_copy:
+        for i, kernel in enumerate(self.kernels_to_copy):
             kernel_path = Path(kernel)
             code_generation.copy_file(
                 kernel_path,
-                str(export_folder / self.kernels_path)
+                str(export_folder / self.kernels_path[i])
             )
             kernel_include_list.append(
-                self.kernels_path + "/" + kernel_path.stem + kernel_path.suffix)
+                self.kernels_path[i] + "/" + kernel_path.stem + kernel_path.suffix)
 
         if self.config_template != "":
             path_to_definition = f"{self.config_path}/{self.attributes['name']}.{self.config_extension}"
-- 
GitLab