diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index fd24008a6de6c58c1e78f088e086817e2a769373..fa6d9634332bf5c7058be4c44a0d6d9e9fb61459 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -29,9 +29,9 @@ class ExportLib(aidge_core.OperatorImpl): # Value: Path where to copy the file relative to the export root static_files: Dict[str, str] = {} # Custom main generation jinja file - main_jinja_path = None - # Main memory section - memory_section = None + main_jinja_path: str = None + # Custom forward generation jinja file + forward_template: str = None # PRIVATE # Registry of exportNode, class level dictionary, shared across all ExportLib _cls_export_node_registry = {} diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index 479eaf01ff8c8e85a3bf83adac88f5ee7fe86857..c76e4976240365a440e0a13dc51518adfc568f46 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -268,11 +268,11 @@ class ExportNodeCpp(ExportNode): # List of includes to add example "include/toto.hpp" include_list: list = None # A list of path of kernels to copy in the export - # kernels are copied in str(export_folder / "include" / "kernels") + # kernels are copied in the path at the same index within the kernels_path list. # They are automatically added to the include list. kernels_to_copy: list = None # Path where all the kernels are stored in the export (prefixed by export_root) - kernels_path: str = "include/kernels" + kernels_path: list = None # Path of config folders config_path: str = "include/layers" # Config_folder_extension @@ -286,16 +286,18 @@ class ExportNodeCpp(ExportNode): raise ValueError("include_list have not been defined") if self.kernels_to_copy is None: raise ValueError("kernels_to_copy have not been defined") + if self.kernels_path is None: + raise ValueError("kernels_path have not been defined") kernel_include_list = [] - for kernel in self.kernels_to_copy: + for i, kernel in enumerate(self.kernels_to_copy): kernel_path = Path(kernel) code_generation.copy_file( kernel_path, - str(export_folder / self.kernels_path) + str(export_folder / self.kernels_path[i]) ) kernel_include_list.append( - self.kernels_path + "/" + kernel_path.stem + kernel_path.suffix) + self.kernels_path[i] + "/" + kernel_path.stem + kernel_path.suffix) if self.config_template != "": path_to_definition = f"{self.config_path}/{self.attributes['name']}.{self.config_extension}" diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index df0b4a385327e4bdccd6fe4de46043d151658dbd..840643a74f3df27d24ada345ea598d191ce95967 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -6,7 +6,7 @@ from aidge_core.export_utils import ExportLib, generate_file, copy_file from typing import List, Tuple -def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None, labels=False) -> None: +def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None, labels=False, test_mode=False) -> None: graphview = scheduler.graph_view() export_folder = Path().absolute() / export_folder_path @@ -107,19 +107,23 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = forward_func = f"void {func_name}({args})" ROOT = Path(__file__).resolve().parents[0] + forward_template = str(ROOT / "templates" / "forward.jinja") + if export_lib.forward_template != None: + forward_template = export_lib.forward_template + generate_file( str(dnn_folder / "src" / "forward.cpp"), - str(ROOT / "templates" / "forward.jinja"), + forward_template, func_name=func_name, headers=set(list_configs), actions=list_actions, mem_ctype=inputs_dtype[0], # Legacy behavior ... - mem_section=export_lib.mem_section, peak_mem=peak_mem, inputs_name=inputs_name, inputs_dtype=inputs_dtype, outputs_name=outputs_name, - outputs_dtype=outputs_dtype + outputs_dtype=outputs_dtype, + test_mode=test_mode ) # Generate dnn API diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index b2c03e794888a0909ada5db208fc07ad266d4ae2..e1f25e86827e81b26436876dce1b98fe0cda80b8 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -153,18 +153,20 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd if (parent.first == node_ptr && parent.second == outputIdx) { // Add-on to display the operator's output dimensions std::string dims = ""; + std::string dtype = ""; const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator()); if (op && !op->getOutput(outputIdx)->undefined()) { dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims()); + dtype += "\n" + fmt::format("{}", op->getOutput(outputIdx)->dataType()); } if (mNodes.find(child) != mNodes.end()) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), - outputIdx, dims, inputIdx, child->type(), namePtrTable.at(child)); + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), + outputIdx, dims, dtype, inputIdx, child->type(), namePtrTable.at(child)); } else if (verbose) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), - outputIdx, dims, inputIdx, static_cast<void*>(child.get())); + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), + outputIdx, dims, dtype, inputIdx, static_cast<void*>(child.get())); } // Do no break here because the same child can be connected to several inputs }