diff --git a/aidge_core/aidge_export_aidge/operator_export/conv.py b/aidge_core/aidge_export_aidge/operator_export/conv.py index 52558980bd3d943e6e8fc0144f37b2f06d4231a4..fb7092fb18982a3cc3f11a1ca47394ce2f77d0b6 100644 --- a/aidge_core/aidge_export_aidge/operator_export/conv.py +++ b/aidge_core/aidge_export_aidge/operator_export/conv.py @@ -1,7 +1,6 @@ from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core import ExportNode, generate_file, generate_str -import os from pathlib import Path @operator_register("Conv") @@ -9,15 +8,9 @@ class Conv(ExportNode): def __init__(self, node): super().__init__(node) - def export(self, export_folder:Path, list_configs:list): - include_path = f"attributes/{self.name}.hpp" filepath = export_folder / f"include/{include_path}" - dirname = os.path.dirname(filepath) - - # If directory doesn't exist, create it - if not os.path.exists(dirname): os.makedirs(dirname) generate_file( filepath, diff --git a/aidge_core/aidge_export_aidge/operator_export/fc.py b/aidge_core/aidge_export_aidge/operator_export/fc.py index af17486911531144c3f63a83a20ebabffecff735..03d4060b27951f00902f16186b381b8dbb5504ac 100644 --- a/aidge_core/aidge_export_aidge/operator_export/fc.py +++ b/aidge_core/aidge_export_aidge/operator_export/fc.py @@ -1,7 +1,6 @@ from aidge_core.aidge_export_aidge.utils import operator_register,parse_node_input from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core import ExportNode, generate_file, generate_str -import os from pathlib import Path @operator_register("FC") @@ -15,10 +14,7 @@ class FC(ExportNode): include_path = f"attributes/{self.name}.hpp" filepath = export_folder / f"include/{include_path}" - dirname = os.path.dirname(filepath) - # If directory doesn't exist, create it - if not os.path.exists(dirname): os.makedirs(dirname) generate_file( filepath, diff --git a/aidge_core/aidge_export_aidge/operator_export/maxpooling.py b/aidge_core/aidge_export_aidge/operator_export/maxpooling.py index a88d31b1aa655f107c97edc4a96f37c9decf99a9..0c63e71b423b90f62536cafd25c61101e76e0562 100644 --- a/aidge_core/aidge_export_aidge/operator_export/maxpooling.py +++ b/aidge_core/aidge_export_aidge/operator_export/maxpooling.py @@ -1,7 +1,6 @@ from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core import ExportNode, generate_file, generate_str -import os from pathlib import Path @operator_register("MaxPooling") @@ -11,13 +10,8 @@ class MaxPooling(ExportNode): def export(self, export_folder:Path, list_configs:list): - include_path = f"attributes/{self.name}.hpp" filepath = export_folder / f"include/{include_path}" - dirname = os.path.dirname(filepath) - - # If directory doesn't exist, create it - if not os.path.exists(dirname): os.makedirs(dirname) generate_file( filepath, diff --git a/aidge_core/aidge_export_aidge/operator_export/producer.py b/aidge_core/aidge_export_aidge/operator_export/producer.py index 4d8e6fe1c0831a0546b294f801a28d9c86541cf1..870ec319af470c8882b45402d3952de60dd0327d 100644 --- a/aidge_core/aidge_export_aidge/operator_export/producer.py +++ b/aidge_core/aidge_export_aidge/operator_export/producer.py @@ -2,7 +2,6 @@ from aidge_core.aidge_export_aidge.utils import operator_register from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core import ExportNode, generate_file, generate_str import numpy as np -import os from pathlib import Path @@ -24,11 +23,6 @@ class Producer(ExportNode): include_path = f"parameters/{self.tensor_name}.hpp" filepath = export_folder / f"include/{include_path}" - dirname = os.path.dirname(filepath) - - # If directory doesn't exist, create it - if not os.path.exists(dirname): os.makedirs(dirname) - aidge_tensor = self.operator.get_output(0) generate_file( filepath, @@ -43,7 +37,7 @@ class Producer(ExportNode): def forward(self, list_actions:list): list_actions.append(generate_str( - str(ROOT_EXPORT) + "/templates/graph_ctor/producer.jinja", + ROOT_EXPORT / "templates/graph_ctor/producer.jinja", name=self.name, tensor_name=self.tensor_name, **self.attributes diff --git a/aidge_core/export_utils/code_generation.py b/aidge_core/export_utils/code_generation.py index b18b5476f8e083bcbe3d4f6c4a57132ebe7b780f..a02fc0966702cec7a2cbe33f8411bb71e3035e90 100644 --- a/aidge_core/export_utils/code_generation.py +++ b/aidge_core/export_utils/code_generation.py @@ -1,47 +1,46 @@ -import os +from pathlib import Path from jinja2 import Environment, FileSystemLoader +from typing import Union -def generate_file(file_path: str, template_path: str, **kwargs) -> None: +def generate_file(file_path: Union[Path, str], template_path: Union[Path, str], **kwargs) -> None: """Generate a file at `file_path` using the jinja template located at `file_path`. kwargs are used to fill the template. :param file_path: path where to generate the file - :type file_path: str + :type file_path: pathlib.Path or str :param template_path: Path to the template to use for code generation - :type template_path: str + :type template_path: pathlib.Path or str """ - # Get directory name of the file - dirname = os.path.dirname(file_path) - - # If directory doesn't exist, create it - if not os.path.exists(dirname): - os.makedirs(dirname) - - # Get directory name and name of the template - template_dir = os.path.dirname(template_path) - template_name = os.path.basename(template_path) + # Convert str -> Path for compatibility ! + if isinstance(file_path, str): + file_path = Path(file_path) + if isinstance(template_path, str): + template_path = Path(template_path) + # Make dir + file_path.parent.mkdir(parents=True, exist_ok=True) # Select template template = Environment(loader=FileSystemLoader( - template_dir)).get_template(template_name) + template_path.parent)).get_template(template_path.name) # Generate file - content = template.render(kwargs) - with open(file_path, mode="w", encoding="utf-8") as message: - message.write(content) + with open(file_path, mode="w", encoding="utf-8") as file: + file.write(template.render(kwargs)) + -def generate_str(template_path:str, **kwargs) -> str: +def generate_str(template_path: Union[Path, str], **kwargs) -> str: """Generate a string using the jinja template located at `file_path`. kwargs are used to fill the template. :param template_path: Path to the template to use for code generation - :type template_path: str + :type template_path: pathlib.Path or str :return: A string of the interpreted template :rtype: str """ - dirname = os.path.dirname(template_path) - filename = os.path.basename(template_path) - template = Environment(loader=FileSystemLoader(dirname)).get_template(filename) - return template.render(kwargs) + # Convert str -> Path for compatibility ! + if isinstance(template_path, str): + template_path = Path(template_path) + return Environment(loader=FileSystemLoader( + template_path.parent)).get_template(template_path.name).render(kwargs)