From 94c3fa62f65c3f4b06d6bb93c5c49e3f785c884e Mon Sep 17 00:00:00 2001 From: Axel Farrugia <axel.farrugia@cea.fr> Date: Fri, 8 Nov 2024 11:44:50 +0100 Subject: [PATCH] feat(export): Add a way to give a custom forward template & add test argument in scheduler_export() function --- aidge_core/export_utils/export_registry.py | 6 +++--- aidge_core/export_utils/node_export.py | 12 +++++++----- aidge_core/export_utils/scheduler_export.py | 12 ++++++++---- src/graph/GraphView.cpp | 10 ++++++---- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index fd24008a6..fa6d96343 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -29,9 +29,9 @@ class ExportLib(aidge_core.OperatorImpl): # Value: Path where to copy the file relative to the export root static_files: Dict[str, str] = {} # Custom main generation jinja file - main_jinja_path = None - # Main memory section - memory_section = None + main_jinja_path: str = None + # Custom forward generation jinja file + forward_template: str = None # PRIVATE # Registry of exportNode, class level dictionary, shared across all ExportLib _cls_export_node_registry = {} diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index 479eaf01f..c76e49762 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -268,11 +268,11 @@ class ExportNodeCpp(ExportNode): # List of includes to add example "include/toto.hpp" include_list: list = None # A list of path of kernels to copy in the export - # kernels are copied in str(export_folder / "include" / "kernels") + # kernels are copied in the path at the same index within the kernels_path list. # They are automatically added to the include list. kernels_to_copy: list = None # Path where all the kernels are stored in the export (prefixed by export_root) - kernels_path: str = "include/kernels" + kernels_path: list = None # Path of config folders config_path: str = "include/layers" # Config_folder_extension @@ -286,16 +286,18 @@ class ExportNodeCpp(ExportNode): raise ValueError("include_list have not been defined") if self.kernels_to_copy is None: raise ValueError("kernels_to_copy have not been defined") + if self.kernels_path is None: + raise ValueError("kernels_path have not been defined") kernel_include_list = [] - for kernel in self.kernels_to_copy: + for i, kernel in enumerate(self.kernels_to_copy): kernel_path = Path(kernel) code_generation.copy_file( kernel_path, - str(export_folder / self.kernels_path) + str(export_folder / self.kernels_path[i]) ) kernel_include_list.append( - self.kernels_path + "/" + kernel_path.stem + kernel_path.suffix) + self.kernels_path[i] + "/" + kernel_path.stem + kernel_path.suffix) if self.config_template != "": path_to_definition = f"{self.config_path}/{self.attributes['name']}.{self.config_extension}" diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index df0b4a385..840643a74 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -6,7 +6,7 @@ from aidge_core.export_utils import ExportLib, generate_file, copy_file from typing import List, Tuple -def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None, labels=False) -> None: +def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None, labels=False, test_mode=False) -> None: graphview = scheduler.graph_view() export_folder = Path().absolute() / export_folder_path @@ -107,19 +107,23 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = forward_func = f"void {func_name}({args})" ROOT = Path(__file__).resolve().parents[0] + forward_template = str(ROOT / "templates" / "forward.jinja") + if export_lib.forward_template != None: + forward_template = export_lib.forward_template + generate_file( str(dnn_folder / "src" / "forward.cpp"), - str(ROOT / "templates" / "forward.jinja"), + forward_template, func_name=func_name, headers=set(list_configs), actions=list_actions, mem_ctype=inputs_dtype[0], # Legacy behavior ... - mem_section=export_lib.mem_section, peak_mem=peak_mem, inputs_name=inputs_name, inputs_dtype=inputs_dtype, outputs_name=outputs_name, - outputs_dtype=outputs_dtype + outputs_dtype=outputs_dtype, + test_mode=test_mode ) # Generate dnn API diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index b2c03e794..e1f25e868 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -153,18 +153,20 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd if (parent.first == node_ptr && parent.second == outputIdx) { // Add-on to display the operator's output dimensions std::string dims = ""; + std::string dtype = ""; const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator()); if (op && !op->getOutput(outputIdx)->undefined()) { dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims()); + dtype += "\n" + fmt::format("{}", op->getOutput(outputIdx)->dataType()); } if (mNodes.find(child) != mNodes.end()) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), - outputIdx, dims, inputIdx, child->type(), namePtrTable.at(child)); + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), + outputIdx, dims, dtype, inputIdx, child->type(), namePtrTable.at(child)); } else if (verbose) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), - outputIdx, dims, inputIdx, static_cast<void*>(child.get())); + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), + outputIdx, dims, dtype, inputIdx, static_cast<void*>(child.get())); } // Do no break here because the same child can be connected to several inputs } -- GitLab