Skip to content
Snippets Groups Projects
Commit bd723d05 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add NodeExportCpp to specify NodeExport and how to generate forward and node definition.

parent 052f6d60
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
Pipeline #50574 failed
...@@ -8,6 +8,6 @@ http://www.eclipse.org/legal/epl-2.0. ...@@ -8,6 +8,6 @@ http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0 SPDX-License-Identifier: EPL-2.0
""" """
from aidge_core.aidge_core import * # import so generated by PyBind from aidge_core.aidge_core import * # import so generated by PyBind
from aidge_core.export_utils import ExportNode, generate_file, generate_str from aidge_core.export_utils import ExportNodeCpp, ExportNode, generate_file, generate_str
import aidge_core.utils import aidge_core.utils
from aidge_core.aidge_export_aidge import * from aidge_core.aidge_export_aidge import *
from pathlib import Path from pathlib import Path
from jinja2 import Environment, FileSystemLoader, StrictUndefined from jinja2 import Environment, FileSystemLoader, StrictUndefined
from typing import Union from typing import Union
import os
import shutil
def generate_file(file_path: Union[Path, str], template_path: Union[Path, str], **kwargs) -> None: def generate_file(file_path: Union[Path, str], template_path: Union[Path, str], **kwargs) -> None:
"""Generate a file at `file_path` using the jinja template located at `file_path`. """Generate a file at `file_path` using the jinja template located at `file_path`.
...@@ -40,3 +41,11 @@ def generate_str(template_path: Union[Path, str], **kwargs) -> str: ...@@ -40,3 +41,11 @@ def generate_str(template_path: Union[Path, str], **kwargs) -> str:
template_path = Path(template_path) template_path = Path(template_path)
return Environment(loader=FileSystemLoader( return Environment(loader=FileSystemLoader(
template_path.parent), undefined=StrictUndefined).get_template(template_path.name).render(kwargs) template_path.parent), undefined=StrictUndefined).get_template(template_path.name).render(kwargs)
def copy_file(filename, dst_folder):
# If directory doesn't exist, create it
if not os.path.exists(dst_folder):
os.makedirs(dst_folder)
shutil.copy(filename, dst_folder)
import aidge_core import aidge_core
from pathlib import Path
from aidge_core.export_utils import data_conversion from aidge_core.export_utils import data_conversion, code_generation
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
def get_chan(tensor: aidge_core.Tensor) -> int: def get_chan(tensor: aidge_core.Tensor) -> int:
"""Given a tensor return the number of channel """Given a tensor return the number of channel
""" """
...@@ -185,14 +185,63 @@ class ExportNode(ABC): ...@@ -185,14 +185,63 @@ class ExportNode(ABC):
self.attributes["out_width"][idx] = get_width(tensor) self.attributes["out_width"][idx] = get_width(tensor)
else: else:
print(f"No output for {self.node.name()}") print(f"No output for {self.node.name()}")
@abstractmethod
class ExportNodeCpp(ExportNode):
# Path to the template defining how to export the node definition
config_template: str = None
# Path to the template defining how to export the node definition
forward_template: str = None
# List of includes to add example "include/toto.hpp"
include_list: list = None
# A list of path of kernels to copy in the export
# kernels are copied in str(export_folder / "include" / "kernels")
# They are automatically added to the include list.
kernels_to_copy: list = None
# Path where all the kernels are stored in the export (prefixed by export_root/include)
kernels_path: str = "kernels"
# def __init__(self, aidge_node: aidge_core.Node) -> None:
# """Create ExportNode and retieve attriubtes from ``aidge_node``:
# """
# super().__init__()
def export(self, export_folder: str, list_configs: list): def export(self, export_folder: str, list_configs: list):
"""Define how to export the node definition. """Define how to export the node definition.
""" """
pass if self.config_template is None:
raise ValueError("config_template have not been defined")
if self.include_list is None:
raise ValueError("include_list have not been defined")
if self.kernels_to_copy is None:
raise ValueError("required_kernels have not been defined")
kernel_include_list = []
for kernel in self.kernels_to_copy:
kernel_path = Path(kernel)
code_generation.copy_file(
kernel_path,
str(export_folder / "include" / self.kernels_path)
)
kernel_include_list.append(self.kernels_path + "/" + kernel_path.stem + kernel_path.suffix)
path_to_definition = f"layers/{self.attributes['name']}.h"
code_generation.generate_file(
str(export_folder / path_to_definition),
self.config_template,
**self.attributes
)
kernel_include_list.append(path_to_definition)
return list_configs + self.include_list + kernel_include_list
@abstractmethod
def forward(self, list_actions: list): def forward(self, list_actions: list):
"""Define how to generate code to perform a forward pass. """Define how to generate code to perform a forward pass.
""" """
pass if self.forward_template is None:
raise ValueError("forward_template have not been defined")
list_actions.append(code_generation.generate_str(
self.forward_template,
**self.attributes
))
return list_actions
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment