Skip to content
Snippets Groups Projects
Commit b7d1251e authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Refactor code_generation.py to use Path.

parent 11a950fd
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!115Aidge export
Pipeline #44699 passed
from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core import ExportNode, generate_file, generate_str from aidge_core import ExportNode, generate_file, generate_str
import os
from pathlib import Path from pathlib import Path
@operator_register("Conv") @operator_register("Conv")
...@@ -9,15 +8,9 @@ class Conv(ExportNode): ...@@ -9,15 +8,9 @@ class Conv(ExportNode):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
def export(self, export_folder:Path, list_configs:list): def export(self, export_folder:Path, list_configs:list):
include_path = f"attributes/{self.name}.hpp" include_path = f"attributes/{self.name}.hpp"
filepath = export_folder / f"include/{include_path}" filepath = export_folder / f"include/{include_path}"
dirname = os.path.dirname(filepath)
# If directory doesn't exist, create it
if not os.path.exists(dirname): os.makedirs(dirname)
generate_file( generate_file(
filepath, filepath,
......
from aidge_core.aidge_export_aidge.utils import operator_register,parse_node_input from aidge_core.aidge_export_aidge.utils import operator_register,parse_node_input
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core import ExportNode, generate_file, generate_str from aidge_core import ExportNode, generate_file, generate_str
import os
from pathlib import Path from pathlib import Path
@operator_register("FC") @operator_register("FC")
...@@ -15,10 +14,7 @@ class FC(ExportNode): ...@@ -15,10 +14,7 @@ class FC(ExportNode):
include_path = f"attributes/{self.name}.hpp" include_path = f"attributes/{self.name}.hpp"
filepath = export_folder / f"include/{include_path}" filepath = export_folder / f"include/{include_path}"
dirname = os.path.dirname(filepath)
# If directory doesn't exist, create it
if not os.path.exists(dirname): os.makedirs(dirname)
generate_file( generate_file(
filepath, filepath,
......
from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core import ExportNode, generate_file, generate_str from aidge_core import ExportNode, generate_file, generate_str
import os
from pathlib import Path from pathlib import Path
@operator_register("MaxPooling") @operator_register("MaxPooling")
...@@ -11,13 +10,8 @@ class MaxPooling(ExportNode): ...@@ -11,13 +10,8 @@ class MaxPooling(ExportNode):
def export(self, export_folder:Path, list_configs:list): def export(self, export_folder:Path, list_configs:list):
include_path = f"attributes/{self.name}.hpp" include_path = f"attributes/{self.name}.hpp"
filepath = export_folder / f"include/{include_path}" filepath = export_folder / f"include/{include_path}"
dirname = os.path.dirname(filepath)
# If directory doesn't exist, create it
if not os.path.exists(dirname): os.makedirs(dirname)
generate_file( generate_file(
filepath, filepath,
......
...@@ -2,7 +2,6 @@ from aidge_core.aidge_export_aidge.utils import operator_register ...@@ -2,7 +2,6 @@ from aidge_core.aidge_export_aidge.utils import operator_register
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core import ExportNode, generate_file, generate_str from aidge_core import ExportNode, generate_file, generate_str
import numpy as np import numpy as np
import os
from pathlib import Path from pathlib import Path
...@@ -24,11 +23,6 @@ class Producer(ExportNode): ...@@ -24,11 +23,6 @@ class Producer(ExportNode):
include_path = f"parameters/{self.tensor_name}.hpp" include_path = f"parameters/{self.tensor_name}.hpp"
filepath = export_folder / f"include/{include_path}" filepath = export_folder / f"include/{include_path}"
dirname = os.path.dirname(filepath)
# If directory doesn't exist, create it
if not os.path.exists(dirname): os.makedirs(dirname)
aidge_tensor = self.operator.get_output(0) aidge_tensor = self.operator.get_output(0)
generate_file( generate_file(
filepath, filepath,
...@@ -43,7 +37,7 @@ class Producer(ExportNode): ...@@ -43,7 +37,7 @@ class Producer(ExportNode):
def forward(self, list_actions:list): def forward(self, list_actions:list):
list_actions.append(generate_str( list_actions.append(generate_str(
str(ROOT_EXPORT) + "/templates/graph_ctor/producer.jinja", ROOT_EXPORT / "templates/graph_ctor/producer.jinja",
name=self.name, name=self.name,
tensor_name=self.tensor_name, tensor_name=self.tensor_name,
**self.attributes **self.attributes
......
import os from pathlib import Path
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from typing import Union
def generate_file(file_path: str, template_path: str, **kwargs) -> None: def generate_file(file_path: Union[Path, str], template_path: Union[Path, str], **kwargs) -> None:
"""Generate a file at `file_path` using the jinja template located at `file_path`. """Generate a file at `file_path` using the jinja template located at `file_path`.
kwargs are used to fill the template. kwargs are used to fill the template.
:param file_path: path where to generate the file :param file_path: path where to generate the file
:type file_path: str :type file_path: pathlib.Path or str
:param template_path: Path to the template to use for code generation :param template_path: Path to the template to use for code generation
:type template_path: str :type template_path: pathlib.Path or str
""" """
# Get directory name of the file # Convert str -> Path for compatibility !
dirname = os.path.dirname(file_path) if isinstance(file_path, str):
file_path = Path(file_path)
# If directory doesn't exist, create it if isinstance(template_path, str):
if not os.path.exists(dirname): template_path = Path(template_path)
os.makedirs(dirname) # Make dir
file_path.parent.mkdir(parents=True, exist_ok=True)
# Get directory name and name of the template
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
# Select template # Select template
template = Environment(loader=FileSystemLoader( template = Environment(loader=FileSystemLoader(
template_dir)).get_template(template_name) template_path.parent)).get_template(template_path.name)
# Generate file # Generate file
content = template.render(kwargs) with open(file_path, mode="w", encoding="utf-8") as file:
with open(file_path, mode="w", encoding="utf-8") as message: file.write(template.render(kwargs))
message.write(content)
def generate_str(template_path:str, **kwargs) -> str: def generate_str(template_path: Union[Path, str], **kwargs) -> str:
"""Generate a string using the jinja template located at `file_path`. """Generate a string using the jinja template located at `file_path`.
kwargs are used to fill the template. kwargs are used to fill the template.
:param template_path: Path to the template to use for code generation :param template_path: Path to the template to use for code generation
:type template_path: str :type template_path: pathlib.Path or str
:return: A string of the interpreted template :return: A string of the interpreted template
:rtype: str :rtype: str
""" """
dirname = os.path.dirname(template_path) # Convert str -> Path for compatibility !
filename = os.path.basename(template_path) if isinstance(template_path, str):
template = Environment(loader=FileSystemLoader(dirname)).get_template(filename) template_path = Path(template_path)
return template.render(kwargs) return Environment(loader=FileSystemLoader(
template_path.parent)).get_template(template_path.name).render(kwargs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment