From b3fe7fd0d2bdfced968f3ac4ab711c319519eea0 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Wed, 24 Jul 2024 12:51:15 +0000
Subject: [PATCH] Update forward and export signature.

---
 aidge_core/export_utils/node_export.py      | 17 ++---
 aidge_core/export_utils/scheduler_export.py | 76 +++++++++------------
 2 files changed, 39 insertions(+), 54 deletions(-)

diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index ce3d53501..5eba1647c 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -269,12 +269,7 @@ class ExportNodeCpp(ExportNode):
     # Path where all the kernels are stored in the export (prefixed by export_root/include)
     kernels_path: str = "kernels"
 
-    # def __init__(self, aidge_node: aidge_core.Node) -> None:
-    #     """Create ExportNode and retieve attriubtes from ``aidge_node``:
-    #     """
-    #     super().__init__()
-
-    def export(self, export_folder: str, list_configs: list):
+    def export(self, export_folder: str):
         """Define how to export the node definition.
         """
         if self.config_template is None:
@@ -302,15 +297,15 @@ class ExportNodeCpp(ExportNode):
 
         kernel_include_list.append(path_to_definition)
 
-        return list_configs + self.include_list + kernel_include_list
+        return self.include_list + kernel_include_list
 
-    def forward(self, list_actions: list):
+    def forward(self):
         """Define how to generate code to perform a forward pass.
         """
         if self.forward_template is None:
             raise ValueError("forward_template have not been defined")
-        list_actions.append(code_generation.generate_str(
+        forward_call: str = code_generation.generate_str(
             self.forward_template,
             **self.attributes
-        ))
-        return list_actions
+        )
+        return [forward_call]
diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py
index 21e228c93..c12f15ab7 100644
--- a/aidge_core/export_utils/scheduler_export.py
+++ b/aidge_core/export_utils/scheduler_export.py
@@ -6,31 +6,22 @@ from aidge_core.export_utils import ExportLib, generate_file, copy_file
 from typing import List, Tuple
 
 
-class ExportScheduler():
-    def __init__(self, scheduler: aidge_core.Scheduler, *args, **kwargs):
-        self.scheduler = scheduler
-        self.graphview = scheduler.graph_view()
-
-    def export(self,
-               export_folder_path: str,
-               export_lib: ExportLib = None,
-               platform=None,
-               memory_manager=None,
-               memory_manager_args=None
-               ) -> None:
+def export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None) -> None:
+        graphview = scheduler.graph_view()
         export_folder = Path().absolute() / export_folder_path
 
         os.makedirs(str(export_folder), exist_ok=True)
 
         dnn_folder = export_folder / "dnn"
         os.makedirs(str(dnn_folder), exist_ok=True)
+
         if memory_manager_args is None:
             memory_manager_args = {}
 
         if memory_manager is None:
             raise ValueError("A memory manager is required (no default value yet).")
         peak_mem, mem_info = memory_manager(
-            self.scheduler, **memory_manager_args)
+            scheduler, **memory_manager_args)
 
         # List of function call
         list_actions: List[str] = []
@@ -43,38 +34,37 @@ class ExportScheduler():
         outputs_dtype: List[str] = []
         outputs_size: List[int] = []
 
-        list_forward_nodes = self.scheduler.get_static_scheduling()
+        list_forward_nodes = scheduler.get_static_scheduling()
         # If exportLib define use it
         # else parse component in platform
-        if export_lib is not None:
-            for node in list_forward_nodes:
-                if export_lib.exportable(node):
-                    is_input = node in self.graphview.get_input_nodes()
-                    is_output = node in self.graphview.get_output_nodes()
-                    op = export_lib.get_export_node(node)(
-                        node, mem_info[node], is_input, is_output)
-                    # For configuration files
-                    list_configs = op.export(dnn_folder, list_configs)
-                    # For forward file
-                    list_actions = op.forward(list_actions)
-                    if is_input:
-                        for idx, node in enumerate(node.inputs()):
-                            if node[0] not in self.graphview.get_nodes():
-                                inputs_name.append(op.attributes["in_name"][idx])
-                                inputs_dtype.append(
-                                    op.attributes["in_cdtype"][idx]
-                                )
-                    if is_output:
-                        for idx in range(len(node.outputs())):
-                            outputs_name.append(op.attributes["out_name"][idx])
-                            outputs_dtype.append(
-                                op.attributes["out_cdtype"][idx])
-                            outputs_size.append(op.attributes["out_size"][idx])
-                else:
-                    raise RuntimeError(
-                        f"Operator not supported: {node.type()} for export lib {export_lib._name} !")
-        else:
-            raise ValueError("Current export only support export lib.")
+        if export_lib is None:
+            raise ValueError("Export need an ExportLib.")
+        for node in list_forward_nodes:
+            if export_lib.exportable(node):
+                is_input = node in graphview.get_input_nodes()
+                is_output = node in graphview.get_output_nodes()
+                op = export_lib.get_export_node(node)(
+                    node, mem_info[node], is_input, is_output)
+                # For configuration files
+                list_configs += op.export(dnn_folder)
+                # For forward file
+                list_actions += op.forward()
+                if is_input:
+                    for idx, node in enumerate(node.inputs()):
+                        if node[0] not in graphview.get_nodes():
+                            inputs_name.append(op.attributes["in_name"][idx])
+                            inputs_dtype.append(
+                                op.attributes["in_cdtype"][idx]
+                            )
+                if is_output:
+                    for idx in range(len(node.outputs())):
+                        outputs_name.append(op.attributes["out_name"][idx])
+                        outputs_dtype.append(
+                            op.attributes["out_cdtype"][idx])
+                        outputs_size.append(op.attributes["out_size"][idx])
+            else:
+                raise RuntimeError(
+                    f"Operator not supported: {node.type()} for export lib {export_lib._name} !")
 
         func_name = "model_forward"
         args = ", ".join([f"const {dtype}* {name}" for name,
-- 
GitLab