diff --git a/aidge_core/export_utils/export_registry.py b/aidge_core/export_utils/export_registry.py index 2490bfa2e11cc37ef9f40d9762795c713133bd13..989ccd367f7b85a50c974aa110036c7aa4eda404 100644 --- a/aidge_core/export_utils/export_registry.py +++ b/aidge_core/export_utils/export_registry.py @@ -41,6 +41,9 @@ class ExportLib(aidge_core.OperatorImpl): static_files: Dict[str, str] = {} # Main memory section mem_section = None + # Custom forward generation jinja file + forward_template: str = None + forward_header_template: str = None # PRIVATE # Registry of exportNode, class level dictionary, shared across all ExportLib _cls_export_node_registry = {} diff --git a/aidge_core/export_utils/scheduler_export.py b/aidge_core/export_utils/scheduler_export.py index 40823b00191aab8c6a53481a340f1f2f8a719102..90862578f04ecd0bdf5ee0bfd7643e7bc3d0a455 100644 --- a/aidge_core/export_utils/scheduler_export.py +++ b/aidge_core/export_utils/scheduler_export.py @@ -6,7 +6,7 @@ from aidge_core.export_utils import ExportLib, generate_file, copy_file from typing import List, Tuple -def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None) -> None: +def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = None, memory_manager=None, memory_manager_args=None, test_mode=False) -> None: """Exports an aidge_core.Scheduler to C++ code. This function generates files for a given computation graph, including forward-pass functions, @@ -57,6 +57,8 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = :type memory_manager: callable :param memory_manager_args: Additional arguments passed to `memory_manager`. Defaults to an empty dictionary. :type memory_manager_args: dict, optional + :param test_mode: Additional argument which may be used during forward generation. + :type test_mode: bool, optional """ graphview = scheduler.graph_view() export_folder = Path().absolute() / export_folder_path @@ -150,9 +152,19 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = func_name = "model_forward" ROOT = Path(__file__).resolve().parents[0] + + forward_template = str(ROOT / "templates" / "forward.jinja") + if export_lib.forward_template != None: + forward_template = export_lib.forward_template + + list_node_names = [] + for node in list_forward_nodes: + if node.type() != "Producer": + list_node_names.append(node.name()) + generate_file( str(dnn_folder / "src" / "forward.cpp"), - str(ROOT / "templates" / "forward.jinja"), + forward_template, func_name=func_name, headers=set(list_configs), actions=list_actions, @@ -165,19 +177,26 @@ def scheduler_export(scheduler, export_folder_path: str, export_lib: ExportLib = inputs_name=inputs_name, inputs_dtype=inputs_dtype, outputs_name=outputs_name, - outputs_dtype=outputs_dtype + outputs_dtype=outputs_dtype, + test_mode=test_mode, + list_node_names=list_node_names ) + forward_header_template = str(ROOT / "templates" / "forward_header.jinja") + if export_lib.forward_header_template != None: + forward_header_template = export_lib.forward_header_template + # Generate dnn API generate_file( str(dnn_folder / "include" / "forward.hpp"), - str(ROOT / "templates" / "forward_header.jinja"), + forward_header_template, libraries=[], func_name=func_name, inputs_name=inputs_name, inputs_dtype=inputs_dtype, outputs_name=outputs_name, - outputs_dtype=outputs_dtype + outputs_dtype=outputs_dtype, + test_mode=test_mode ) if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index b1e879d8387a71a3819ee7e0f8bbcd1e9936c146..cdf8cc23250366c62a4102118a95e68cec28ec3d 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> #include "aidge/data/Data.hpp" @@ -65,6 +66,7 @@ void init_Data(py::module& m){ m.def("format_as", (const char* (*)(DataType)) &format_as, py::arg("dt")); m.def("format_as", (const char* (*)(DataFormat)) &format_as, py::arg("df")); + m.def("get_data_format_transpose", &getDataFormatTranspose, py::arg("src"), py::arg("dst")); } } diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp index fe606cfb557042d581e09da7419d80841d1dc2d4..fc800816b7ab012cd87d6af8d451392b39563e4f 100644 --- a/python_binding/data/pybind_Tensor.cpp +++ b/python_binding/data/pybind_Tensor.cpp @@ -318,6 +318,7 @@ void init_Tensor(py::module& m){ .def("clone", &Tensor::clone) .def("sqrt", &Tensor::sqrt) .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true) + .def("set_data_format", &Tensor::setDataFormat, py::arg("data_format"), py::arg("copyTrans") = true) .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("grad", &Tensor::grad) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index b2c03e794888a0909ada5db208fc07ad266d4ae2..e1f25e86827e81b26436876dce1b98fe0cda80b8 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -153,18 +153,20 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd if (parent.first == node_ptr && parent.second == outputIdx) { // Add-on to display the operator's output dimensions std::string dims = ""; + std::string dtype = ""; const auto op = std::dynamic_pointer_cast<OperatorTensor>(node_ptr->getOperator()); if (op && !op->getOutput(outputIdx)->undefined()) { dims += " " + fmt::format("{}", op->getOutput(outputIdx)->dims()); + dtype += "\n" + fmt::format("{}", op->getOutput(outputIdx)->dataType()); } if (mNodes.find(child) != mNodes.end()) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), - outputIdx, dims, inputIdx, child->type(), namePtrTable.at(child)); + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}_{}\n", node_ptr->type(), namePtrTable.at(node_ptr), + outputIdx, dims, dtype, inputIdx, child->type(), namePtrTable.at(child)); } else if (verbose) { - fmt::print(fp.get(), "{}_{}-->|\"{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), - outputIdx, dims, inputIdx, static_cast<void*>(child.get())); + fmt::print(fp.get(), "{}_{}-->|\"{}{}{}→{}\"|{}:::externalCls\n", node_ptr->type(), namePtrTable.at(node_ptr), + outputIdx, dims, dtype, inputIdx, static_cast<void*>(child.get())); } // Do no break here because the same child can be connected to several inputs }