diff --git a/aidge_core/__init__.py b/aidge_core/__init__.py index 4234747b94b25b35a6836a36ad6331c1c8a4bc66..38f5faa94b3fd08287977bdc4e1af60650af69cd 100644 --- a/aidge_core/__init__.py +++ b/aidge_core/__init__.py @@ -8,6 +8,6 @@ http://www.eclipse.org/legal/epl-2.0. SPDX-License-Identifier: EPL-2.0 """ from aidge_core.aidge_core import * # import so generated by PyBind -from aidge_core.export_utils import ExportNode, generate_file, generate_str +from aidge_core.export_utils import ExportNodeCpp, ExportNode, generate_file, generate_str import aidge_core.utils from aidge_core.aidge_export_aidge import * diff --git a/aidge_core/export_utils/code_generation.py b/aidge_core/export_utils/code_generation.py index 29b24bb82314cdd85db78c71b98bbead68a0b855..d48970b0fcf247bef0c3df4052837ab78e360238 100644 --- a/aidge_core/export_utils/code_generation.py +++ b/aidge_core/export_utils/code_generation.py @@ -1,7 +1,8 @@ from pathlib import Path from jinja2 import Environment, FileSystemLoader, StrictUndefined from typing import Union - +import os +import shutil def generate_file(file_path: Union[Path, str], template_path: Union[Path, str], **kwargs) -> None: """Generate a file at `file_path` using the jinja template located at `file_path`. @@ -40,3 +41,11 @@ def generate_str(template_path: Union[Path, str], **kwargs) -> str: template_path = Path(template_path) return Environment(loader=FileSystemLoader( template_path.parent), undefined=StrictUndefined).get_template(template_path.name).render(kwargs) + +def copy_file(filename, dst_folder): + + # If directory doesn't exist, create it + if not os.path.exists(dst_folder): + os.makedirs(dst_folder) + + shutil.copy(filename, dst_folder) diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index 78a103aea5d16cf9a6094975cb7a066c089d20a9..d09a7474258940b6a648ff87b0157425890e335b 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -1,9 +1,9 @@ import aidge_core +from pathlib import Path -from aidge_core.export_utils import data_conversion +from aidge_core.export_utils import data_conversion, code_generation from abc import ABC, abstractmethod - def get_chan(tensor: aidge_core.Tensor) -> int: """Given a tensor return the number of channel """ @@ -185,14 +185,63 @@ class ExportNode(ABC): self.attributes["out_width"][idx] = get_width(tensor) else: print(f"No output for {self.node.name()}") - @abstractmethod + + +class ExportNodeCpp(ExportNode): + # Path to the template defining how to export the node definition + config_template: str = None + # Path to the template defining how to export the node definition + forward_template: str = None + # List of includes to add example "include/toto.hpp" + include_list: list = None + # A list of path of kernels to copy in the export + # kernels are copied in str(export_folder / "include" / "kernels") + # They are automatically added to the include list. + kernels_to_copy: list = None + # Path where all the kernels are stored in the export (prefixed by export_root/include) + kernels_path: str = "kernels" + + # def __init__(self, aidge_node: aidge_core.Node) -> None: + # """Create ExportNode and retieve attriubtes from ``aidge_node``: + # """ + # super().__init__() + def export(self, export_folder: str, list_configs: list): """Define how to export the node definition. """ - pass + if self.config_template is None: + raise ValueError("config_template have not been defined") + if self.include_list is None: + raise ValueError("include_list have not been defined") + if self.kernels_to_copy is None: + raise ValueError("required_kernels have not been defined") + + kernel_include_list = [] + for kernel in self.kernels_to_copy: + kernel_path = Path(kernel) + code_generation.copy_file( + kernel_path, + str(export_folder / "include" / self.kernels_path) + ) + kernel_include_list.append(self.kernels_path + "/" + kernel_path.stem + kernel_path.suffix) + path_to_definition = f"layers/{self.attributes['name']}.h" + code_generation.generate_file( + str(export_folder / path_to_definition), + self.config_template, + **self.attributes + ) + + kernel_include_list.append(path_to_definition) + + return list_configs + self.include_list + kernel_include_list - @abstractmethod def forward(self, list_actions: list): """Define how to generate code to perform a forward pass. """ - pass + if self.forward_template is None: + raise ValueError("forward_template have not been defined") + list_actions.append(code_generation.generate_str( + self.forward_template, + **self.attributes + )) + return list_actions