From bd723d05f52f917eb043e6434a12dcb467b4a357 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Wed, 10 Jul 2024 09:38:55 +0000
Subject: [PATCH] Add NodeExportCpp to specify NodeExport and how to generate
 forward and node definition.

---
 aidge_core/__init__.py                     |  2 +-
 aidge_core/export_utils/code_generation.py | 11 +++-
 aidge_core/export_utils/node_export.py     | 61 +++++++++++++++++++---
 3 files changed, 66 insertions(+), 8 deletions(-)

diff --git a/aidge_core/__init__.py b/aidge_core/__init__.py
index 4234747b9..38f5faa94 100644
--- a/aidge_core/__init__.py
+++ b/aidge_core/__init__.py
@@ -8,6 +8,6 @@ http://www.eclipse.org/legal/epl-2.0.
 SPDX-License-Identifier: EPL-2.0
 """
 from aidge_core.aidge_core import * # import so generated by PyBind
-from aidge_core.export_utils import ExportNode, generate_file, generate_str
+from aidge_core.export_utils import ExportNodeCpp, ExportNode, generate_file, generate_str
 import aidge_core.utils
 from aidge_core.aidge_export_aidge import *
diff --git a/aidge_core/export_utils/code_generation.py b/aidge_core/export_utils/code_generation.py
index 29b24bb82..d48970b0f 100644
--- a/aidge_core/export_utils/code_generation.py
+++ b/aidge_core/export_utils/code_generation.py
@@ -1,7 +1,8 @@
 from pathlib import Path
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
 from typing import Union
-
+import os
+import shutil
 
 def generate_file(file_path: Union[Path, str], template_path: Union[Path, str], **kwargs) -> None:
     """Generate a file at `file_path` using the jinja template located at `file_path`.
@@ -40,3 +41,11 @@ def generate_str(template_path: Union[Path, str], **kwargs) -> str:
         template_path = Path(template_path)
     return Environment(loader=FileSystemLoader(
         template_path.parent), undefined=StrictUndefined).get_template(template_path.name).render(kwargs)
+
+def copy_file(filename, dst_folder):
+
+    # If directory doesn't exist, create it
+    if not os.path.exists(dst_folder):
+        os.makedirs(dst_folder)
+
+    shutil.copy(filename, dst_folder)
diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index 78a103aea..d09a74742 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -1,9 +1,9 @@
 import aidge_core
+from pathlib import Path
 
-from aidge_core.export_utils import data_conversion
+from aidge_core.export_utils import data_conversion, code_generation
 from abc import ABC, abstractmethod
 
-
 def get_chan(tensor: aidge_core.Tensor) -> int:
     """Given a tensor return the number of channel
     """
@@ -185,14 +185,63 @@ class ExportNode(ABC):
                 self.attributes["out_width"][idx] = get_width(tensor)
             else:
                 print(f"No output for {self.node.name()}")
-    @abstractmethod
+
+
+class ExportNodeCpp(ExportNode):
+    # Path to the template defining how to export the node definition
+    config_template: str = None
+    # Path to the template defining how to export the node definition
+    forward_template: str = None
+    # List of includes to add example "include/toto.hpp"
+    include_list: list = None
+    # A list of path of kernels to copy in the export
+    # kernels are copied in str(export_folder / "include" / "kernels")
+    # They are automatically added to the include list.
+    kernels_to_copy: list = None
+    # Path where all the kernels are stored in the export (prefixed by export_root/include)
+    kernels_path: str = "kernels"
+
+    # def __init__(self, aidge_node: aidge_core.Node) -> None:
+    #     """Create ExportNode and retieve attriubtes from ``aidge_node``:
+    #     """
+    #     super().__init__()
+
     def export(self, export_folder: str, list_configs: list):
         """Define how to export the node definition.
         """
-        pass
+        if self.config_template is None:
+            raise ValueError("config_template have not been defined")
+        if self.include_list is None:
+            raise ValueError("include_list have not been defined")
+        if self.kernels_to_copy is None:
+            raise ValueError("required_kernels have not been defined")
+
+        kernel_include_list = []
+        for kernel in self.kernels_to_copy:
+            kernel_path = Path(kernel)
+            code_generation.copy_file(
+                kernel_path,
+                str(export_folder / "include" / self.kernels_path)
+            )
+            kernel_include_list.append(self.kernels_path + "/" + kernel_path.stem + kernel_path.suffix)
+        path_to_definition =  f"layers/{self.attributes['name']}.h"
+        code_generation.generate_file(
+            str(export_folder / path_to_definition),
+            self.config_template,
+            **self.attributes
+        )
+
+        kernel_include_list.append(path_to_definition)
+
+        return list_configs + self.include_list + kernel_include_list
 
-    @abstractmethod
     def forward(self, list_actions: list):
         """Define how to generate code to perform a forward pass.
         """
-        pass
+        if self.forward_template is None:
+            raise ValueError("forward_template have not been defined")
+        list_actions.append(code_generation.generate_str(
+            self.forward_template,
+            **self.attributes
+        ))
+        return list_actions
-- 
GitLab