From 5d9072be15594fbeec1e5a58854ad4662e136c32 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 7 Mar 2025 01:52:04 +0000 Subject: [PATCH 1/5] [Add][WIP] benchmark scripts - main performance measurement script - specific functions for measuring onnxruntime and torch libraries - comparison graph generation script --- benchmark/benchmark.py | 326 +++++++++++++++++++++++++++++ benchmark/benchmark_onnxruntime.py | 39 ++++ benchmark/benchmark_torch.py | 65 ++++++ benchmark/generate_graph.py | 256 ++++++++++++++++++++++ 4 files changed, 686 insertions(+) create mode 100644 benchmark/benchmark.py create mode 100644 benchmark/benchmark_onnxruntime.py create mode 100644 benchmark/benchmark_torch.py create mode 100644 benchmark/generate_graph.py diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py new file mode 100644 index 000000000..6726be68a --- /dev/null +++ b/benchmark/benchmark.py @@ -0,0 +1,326 @@ +""" +Operator Kernel Performance Benchmarking Script + +This script benchmarks operator kernels using a specified inference module. +It supports timing measurements and (optionally) output comparisons with ONNXRuntime. +The configuration is provided via a JSON file. +""" + +import argparse +import copy +import importlib +import json +import os +import sys + +import numpy as np +import onnx + +import aidge_core as ai +import aidge_onnx +from aidge_onnx.generate_singleop_onnx import create_onnx_model + +# Benchmark parameters +NB_WARMUPS: int = 10 +NB_ITERATIONS: int = 50 + + +def load_inference_module(module_name: str): + """ + Dynamically imports and returns the inference module. + Exits if the module is not installed. + """ + try: + module = importlib.import_module(module_name) + print(f"'{module_name}' module successfully imported") + return module + except ImportError: + print( + f"Error: {module_name} is not installed. Install it using 'pip install {module_name}'." + ) + sys.exit(1) + + +def load_config(config_file: str) -> dict: + """ + Loads and returns the JSON configuration from the given file. + """ + with open(config_file, "r") as f: + return json.load(f) + + +def update_test_config( + param: str, + value, + base_attributes: dict, + base_input_shapes: list, + other_parameters: dict, + operator_attributes: list, +): + """ + Updates the operator attributes and input shapes based on the test parameter. + + Returns: + tuple: (updated_attributes, updated_input_shapes) or (None, None) if keys are missing. + """ + attributes = copy.deepcopy(base_attributes) + input_shapes = copy.deepcopy(base_input_shapes) + + # Update if the parameter is a valid operator attribute + if param in operator_attributes: + attributes[param] = value + + try: + extra_attrs = other_parameters[param][str(value)]["attributes"] + except KeyError: + print( + f"'{param}': '{value}': 'attributes' - Key not found in other_parameters. Config file may be ill-formed." + ) + return None, None + attributes.update(extra_attrs) + + try: + extra_input_shapes = other_parameters[param][str(value)]["input_shapes"] + except KeyError: + print( + f"'{param}': '{value}': 'input_shapes' - Key not found in other_parameters. Config file may be ill-formed." 
+        )
+        return None, None
+
+    for shape_update in extra_input_shapes:
+        name, new_shape = shape_update
+        for base_shape in input_shapes:
+            if base_shape[0] == name:
+                base_shape[1] = new_shape
+                break
+
+    return attributes, input_shapes
+
+
+# def get_results_file_path(module_name: str, operator_aidge: str, save_directory: str) -> str:
+#     """
+#     Constructs and returns the file path for saving the benchmark results.
+#     """
+#     if module_name == "onnxruntime":
+#         filename = f"{operator_aidge}_onnxruntime.json"
+#     elif module_name == "pytorch":
+#         filename = f"{operator_aidge}_pytorch.json"
+#     else:
+#         filename = f"{operator_aidge}.json"
+#     return os.path.join(save_directory, filename)
+
+
+def measure_inference_time(
+    module_name: str, model: onnx.ModelProto, input_data, inference_module=None
+) -> list[float]:
+    """
+    Measures inference time using the appropriate benchmark function.
+    """
+    if module_name == "onnxruntime":
+        import benchmark_onnxruntime
+
+        return benchmark_onnxruntime.measure_inference_time(
+            model, {v[0]: v[1] for v in input_data}, NB_WARMUPS, NB_ITERATIONS
+        )
+    elif module_name == "torch":
+        import benchmark_torch
+
+        return benchmark_torch.measure_inference_time(
+            model, {v[0]: v[1] for v in input_data}, NB_WARMUPS, NB_ITERATIONS
+        )
+    else:
+        model = aidge_onnx.load(model=model) if "aidge" in module_name else model
+        return inference_module.benchmark.measure_inference_time(
+            model, input_data, NB_WARMUPS, NB_ITERATIONS
+        )
+
+
+def compute_output(
+    module_name: str, model: onnx.ModelProto, input_data, inference_module
+) -> list[np.ndarray]:
+    """
+    Computes the model output(s) using the appropriate benchmark function.
+    """
+    if module_name == "onnxruntime":
+        import benchmark_onnxruntime
+
+        return benchmark_onnxruntime.compute_output(
+            model, {v[0]: v[1] for v in input_data}
+        )
+    elif module_name == "torch":
+        import benchmark_torch
+
+        return benchmark_torch.compute_output(model, {v[0]: v[1] for v in input_data})
+    else:
+        model = aidge_onnx.load(model=model) if "aidge" in module_name else model
+        return inference_module.benchmark.compute_output(model, input_data)
+
+
+def prepare_input_data(
+    input_shapes: list[tuple[str, list[int]]], initializer_rank: int
+) -> list[tuple[str, np.ndarray]]:
+    """
+    Generates random input data for the first `initializer_rank` inputs.
+ """ + data: list[str, np.ndarray] = [] + for i, conf in enumerate(input_shapes): + name, shape = conf + if i < initializer_rank: + random_array = np.array(np.random.rand(*shape)).astype(np.float32) + data.append((name, random_array)) + return data + + +def main(): + parser = argparse.ArgumentParser( + description="Operator Kernel Performance Benchmarking" + ) + parser.add_argument( + "--config-file", + "-cf", + type=str, + required=True, + help="Path to configuration JSON with operator type, attributes, and input sizes.", + ) + parser.add_argument( + "--module-to-bench", + "-mtb", + type=str, + required=True, + help="Name of the module containing the inference functions", + ) + parser.add_argument( + "--compare-with-onnxruntime", + "-cwo", + action="store_true", + help="Compare output with ONNXRuntime", + ) + parser.add_argument( + "--time", "-t", action="store_true", help="Compute inference time" + ) + parser.add_argument( + "--save-directory", + type=str, + required=True, + help="Directory to save the results", + ) + args = parser.parse_args() + + compare_mode = args.compare_with_onnxruntime + time_mode = args.time + module_name = args.module_to_bench + save_directory = args.save_directory + + # Load the inference module + inference_module = load_inference_module(module_name) + + # Configure aidge logging + ai.Log.set_console_level(ai.Level.Error) + ai.Log.set_precision(10) + + # Load configuration + config = load_config(args.config_file) + operator_name: str = config["operator"] + opset_version: int = config["opset_version"] + initializer_rank: int = config.get("initializer_rank", 1) + + base_input_shapes: list[str, list[int]] = config["base_configuration"][ + "input_shapes" + ] + base_attributes: dict = config["base_configuration"].get("attributes", {}) + + main_parameters: dict[str, Any] = config["test_configuration"].get( + "main_parameters", {} + ) + other_parameters: dict[str, dict] = config["test_configuration"].get( + "other_parameters", {} + ) + + # Get operator attributes from the schema for filtering test parameters + operator_schema = onnx.defs.get_schema(operator_name, opset_version) + operator_attributes: list[str] = list(operator_schema.attributes) + + # Create a base ONNX model to determine the operator type for naming the results file + base_model: onnx.ModelProto = create_onnx_model( + operator_name, + opset_version, + base_input_shapes, + initializer_rank, + **base_attributes, + ) + operator_aidge: str = list(aidge_onnx.load(model=base_model).get_ordered_outputs())[0][0].type() + + # Initialize or load existing benchmark results + results = {"library": "", "compare": {}, "time": {}} + filename: str = f"{operator_aidge.lower()}_{module_name}.json" + results_file_path = os.path.join(save_directory, filename) + # results_file_path = get_results_file_path(module_name, operator_aidge, save_directory) + if os.path.exists(results_file_path): + with open(results_file_path, "r") as f: + results = json.load(f) + results["library"] = module_name + + # Loop over each test parameter and its values + print("Starting tests...") + for param, test_values in main_parameters.items(): + if time_mode: + results["time"].setdefault(param, {}) + if compare_mode: + results["compare"].setdefault(param, {}) + + for value in test_values: + print(f"▷ {param} -- {value}") + updated_attrs, updated_input_shapes = update_test_config( + param, + value, + base_attributes, + base_input_shapes, + other_parameters, + operator_attributes, + ) + if updated_attrs is None or updated_input_shapes is None: + 
+                continue  # Skip this test case if configuration is missing
+
+            # Create the updated ONNX model
+            model = create_onnx_model(
+                operator_name,
+                opset_version,
+                updated_input_shapes,
+                initializer_rank,
+                **updated_attrs,
+            )
+
+            input_data = prepare_input_data(updated_input_shapes, initializer_rank)
+
+            if time_mode:
+                print(f" {'├' if compare_mode else '└'}┬─Measuring kernel inference time...")
+                timing = measure_inference_time(module_name, model, input_data, inference_module)
+                results["time"][param][str(value)] = timing
+                print(
+                    f" {'│' if compare_mode else ' '}└─[ time = {np.array(timing).mean():.2e} ± {np.array(timing).std():.2e} seconds ]"
+                )
+
+            if compare_mode:
+                print(f" └┬─Assessing results are the same as 'onnxruntime' module...")
+                ref = compute_output("onnxruntime", model, input_data, None)
+                tested = compute_output(module_name, model, input_data, inference_module)
+
+                if len(ref) > 1:
+                    print("Multi-output comparison not handled yet")
+                    sys.exit(1)
+                results["compare"][param][str(value)] = bool(
+                    np.all(np.isclose(ref, tested))
+                )
+                print(
+                    f" └─[ {'o' if results['compare'][param][str(value)] else 'x'} ]"
+                )
+        print()
+
+    # Save results
+    print(f"Writing results to JSON '{results_file_path}'")
+    with open(results_file_path, "w") as outfile:
+        json.dump(results, outfile, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/benchmark_onnxruntime.py b/benchmark/benchmark_onnxruntime.py
new file mode 100644
index 000000000..cb7336bcf
--- /dev/null
+++ b/benchmark/benchmark_onnxruntime.py
@@ -0,0 +1,39 @@
+import numpy as np
+import onnxruntime as ort
+from onnx import ModelProto
+import time
+
+def measure_inference_time(model: ModelProto, input_data: dict[str, np.ndarray], nb_warmup: int = 10, nb_iterations: int = 50) -> list[float]:
+    """
+    Run the provided ONNX model using ONNXRuntime.
+    Performs `nb_warmup` warm-up runs followed by `nb_iterations` timed runs (using CPU process time).
+
+    Args:
+        model: The ONNX model (ModelProto).
+        input_data: Dictionary mapping all input names to NumPy arrays.
+
+    Returns:
+        List of CPU times (in seconds) for the timed runs.
+    """
+    sess_opt = ort.SessionOptions()
+    sess_opt.intra_op_num_threads = 1
+    sess = ort.InferenceSession(model.SerializeToString(), sess_opt)
+
+    timings = []
+    # Warm-up runs first, then timed runs.
+    for i in range(nb_warmup + nb_iterations):
+        if i < nb_warmup:
+            sess.run(None, input_data)
+        else:
+            start = time.process_time()
+            sess.run(None, input_data)
+            end = time.process_time()
+            timings.append((end - start))
+    return timings
+
+def compute_output(model: ModelProto, input_data: dict[str, np.ndarray]) -> list[np.ndarray]:
+    sess = ort.InferenceSession(model.SerializeToString())
+    # Run the session with the provided input_data.
+    outputs = sess.run(None, input_data)
+    # Return all outputs.
+    return np.array(outputs)
\ No newline at end of file
diff --git a/benchmark/benchmark_torch.py b/benchmark/benchmark_torch.py
new file mode 100644
index 000000000..0275daa01
--- /dev/null
+++ b/benchmark/benchmark_torch.py
@@ -0,0 +1,65 @@
+import torch
+import numpy as np
+from onnx import ModelProto
+from onnx2torch import convert
+import time
+
+def measure_inference_time(model_onnx: ModelProto, input_data: dict[str, np.ndarray], nb_warmup: int = 10, nb_iterations: int = 50) -> list[float]:
+    """
+    Convert the provided ONNX model to PyTorch and run it.
+    Performs `nb_warmup` warm-up runs followed by `nb_iterations` timed runs (using CPU process time).
+
+    Args:
+        model_onnx: The ONNX model.
+        input_data: Dictionary mapping all input names to NumPy arrays.
+ + Returns: + List of CPU times (in seconds) for the 50 timed runs. + """ + model = convert(model_onnx) + + device = torch.device("cpu") + model.to(device) + model.eval() + + torch.set_num_threads(1) + + inputs = [torch.tensor(v, device=device) for _, v in input_data.items()] + timings = [] + + with torch.no_grad(): + # Warm-up runs + for i in range(nb_warmup + nb_iterations): + if i < nb_warmup: + model(*inputs) + else: + start = time.process_time() + model(*inputs) + end = time.process_time() + timings.append(end - start) + return timings + +def compute_output(model_onnx: ModelProto, input_data: dict[str, np.ndarray]) -> list[np.ndarray]: + """ + Run the PyTorch model inference. + + Args: + model: The PyTorch model. + input_data: Dictionary mapping all input names to NumPy arrays. + + Returns: + The first output tensor if there is only one, else a list of output tensors. + """ + model = convert(model_onnx) + + device = torch.device("cpu") + model.to(device) + model.eval() + + inputs = [torch.tensor(v, device=device) for _, v in input_data.items()] + + with torch.no_grad(): + outputs = model(*inputs) + outputs = [o.numpy() if isinstance(o, torch.Tensor) else np.array(o) for o in outputs] + + return outputs diff --git a/benchmark/generate_graph.py b/benchmark/generate_graph.py new file mode 100644 index 000000000..720755225 --- /dev/null +++ b/benchmark/generate_graph.py @@ -0,0 +1,256 @@ +import argparse +import json +import sys +import textwrap +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +# Set a default style +sns.set_theme(style="ticks", palette="flare") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Compare time performance of operator kernels by plotting relative differences." + ) + parser.add_argument( + "--operator-config", "-oc", type=str, default="config.json", + help="Path to the configuration JSON file (default: config.json)" + ) + parser.add_argument( + "--ref", "-r", type=str, required=True, + help="Path to the JSON file with reference results" + ) + parser.add_argument( + "--libs", "-l", type=str, nargs='+', required=True, + help=("Paths to one or more JSON files with library results. For violin/box mode, " + "exactly one file must be provided. For bar mode, multiple files can be provided.") + ) + # parser.add_argument( + # "--plot-type", "-pt", type=str, choices=["violin", "box", "bar"], default="violin", + # help="Type of plot to use: 'violin', 'box', or 'bar' (default: violin)" + # ) + return parser.parse_args() + + +def load_json(file_path: str): + with open(file_path, 'r') as f: + return json.load(f) + + +def create_relative_difference_plots(test_parameters: dict, ref_times: dict, ref_library: str, + plot_type: str, new_times: dict = None, new_library: str = None, + libraries: list = None): + """ + Creates subplots comparing relative differences. + + For "violin" and "box" modes, `new_times` and `new_library` are used. + For "bar" mode, a list of library tuples (library_name, times) in `libraries` is used. + In bar mode the reference library (ref_library) is always added as the baseline (ratio = 1). 
+ """ + n_params = len(test_parameters) + n_cols = 1 + if n_params > 1: + if n_params == 2 or n_params == 4: + n_cols = 2 + else: + n_cols = 3 + + n_rows = (n_params + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(10 if n_params == 1 else 15, 5 * n_rows)) + # Ensure axes is always iterable + if n_params == 1: + axes = [axes] + axes_flat = axes.flatten() if n_params > 1 else axes + + for idx, (param_name, param_values) in enumerate(test_parameters.items()): + ax = axes_flat[idx] + # if plot_type in ["violin", "box"]: + # # Compute relative differences (%) per run + # plot_data = [] + # for val in param_values: + # new_arr = np.array(new_times[param_name][str(val)]) + # ref_arr = np.array(ref_times[param_name][str(val)]) + # rel_diff = (new_arr - ref_arr) / ref_arr * 100 + # plot_data.extend([(str(val), diff) for diff in rel_diff]) + # df = pd.DataFrame(plot_data, columns=[f'{param_name} Value', 'Relative Difference (%)']) + # # Optionally filter extreme outliers using IQR + # for col in df.select_dtypes(include='number').columns: + # Q1 = df[col].quantile(0.25) + # Q3 = df[col].quantile(0.75) + # IQR = Q3 - Q1 + # lower_bound = Q1 - IQR + # upper_bound = Q3 + IQR + # df = df[(df[col] >= lower_bound) & (df[col] <= upper_bound)] + # if plot_type == "violin": + # sns.violinplot( + # data=df, + # x=f'{param_name} Value', + # y='Relative Difference (%)', + # hue=f'{param_name} Value', + # palette='flare', + # inner='quartile', + # ax=ax + # ) + # else: # box plot + # sns.boxplot( + # data=df, + # x=f'{param_name} Value', + # y='Relative Difference (%)', + # hue=f'{param_name} Value', + # palette='flare', + # ax=ax + # ) + # if ax.get_legend() is not None: + # ax.legend_.remove() + # ax.grid(True, axis='y', alpha=0.5, color='gray') + # ax.axhline(y=0, color='red', linewidth=1) + # ax.set_ylim(-30, 150) + # stats_text = (f'Mean: {df["Relative Difference (%)"].mean():.2f}%\n' + # f'Std: {df["Relative Difference (%)"].std():.2f}%') + # ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, + # verticalalignment='top', + # bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + # elif plot_type == "bar": + # For bar plots: compute the median time ratio for each library compared to reference + data = [] + for val in param_values: + data.append((str(val), ref_library, 1.0)) # Reference baseline (ratio = 1) + for lib_name, lib_times in libraries: + lib_arr = np.array(lib_times[param_name][str(val)]) + ref_arr = np.array(ref_times[param_name][str(val)]) + median_lib = np.median(lib_arr) + median_ref = np.median(ref_arr) + ratio = median_lib / median_ref if median_ref != 0 else np.nan + data.append((str(val), lib_name, ratio)) + df_bar = pd.DataFrame(data, columns=[f'{param_name}', 'Library', 'Ratio']) + sns.barplot( + data=df_bar, + x=f'{param_name}', + y='Ratio', + hue='Library', + palette='viridis', + ax=ax, + errorbar=None + ) + ax.grid(True, axis='y', alpha=0.5, color='gray') + # Annotate bars with their ratio values + for container in ax.containers: + labels = [f'{h:.2f}' if h > 1e-6 else '' for h in container.datavalues] + ax.bar_label(container, labels=labels, padding=3) + ax.set_ylim(0, max(df_bar['Ratio'].max() * 1.1, 1.1)) + # else: + # ax.text(0.5, 0.5, "Unknown plot type", horizontalalignment='center', verticalalignment='center') + + # Remove any unused subplots + for idx in range(len(test_parameters), len(axes_flat)): + fig.delaxes(axes_flat[idx]) + if n_params == 1: + plt.tight_layout(rect=[0, 0.05, 1, 0.88]) + else: + plt.tight_layout(rect=[0, 0.05, 1, 
0.93]) + + # Create a common legend (if any) at the top center + common_handles, common_labels = None, None + for ax in fig.axes: + leg = ax.get_legend() + if leg is not None: + common_handles, common_labels = ax.get_legend_handles_labels() + break + if common_handles is not None and common_labels is not None: + fig.legend(common_handles, common_labels, loc='upper center', ncol=len(common_labels), + bbox_to_anchor=(0.5, 0.99), title="Library", fontsize=14) + # Remove legends from individual subplots + for ax in axes_flat: + if ax.get_legend() is not None: + ax.get_legend().remove() + return fig + + +def main(): + args = parse_args() + config = load_json(args.operator_config) + ref_results = load_json(args.ref) + library_files = args.libs + + operator = config["operator"] + test_parameters = config["test_configuration"].get("main_parameters", {}) + + # Load reference times and library name from reference JSON + ref_times = ref_results.get("time") + ref_library = ref_results.get("library", "ref_lib") + if ref_times is None: + print("Reference JSON does not contain time results.") + sys.exit(1) + + # if args.plot_type in ["violin", "box"]: + # if len(library_files) != 1: + # print("Error: For violin/box mode, exactly one library JSON file must be provided.") + # sys.exit(1) + # comp_results = load_json(library_files[0]) + # comp_times = comp_results.get("time") + # comp_library = comp_results.get("library", "comp_lib") + # if comp_times is None: + # print("Library JSON does not contain time results.") + # sys.exit(1) + # fig = create_relative_difference_plots( + # test_parameters, ref_times, ref_library, + # plot_type=args.plot_type, new_times=comp_times, new_library=comp_library + # ) + # filename = f"{operator}_comp_{comp_library}-vs-{ref_library}.svg" + # elif args.plot_type == "bar": + libraries = [] + for lib_file in library_files: + lib_result = load_json(lib_file) + lib_times = lib_result.get("time") + lib_name = lib_result.get("library", "lib") + if lib_times is None: + print(f"Library JSON {lib_file} does not contain time results. Skipping.") + continue + libraries.append((lib_name, lib_times)) + if not libraries: + print("No valid library results available for bar plot.") + sys.exit(1) + fig = create_relative_difference_plots( + test_parameters, ref_times, ref_library, + plot_type="bar", libraries=libraries + ) + # lib_names = "-".join([name for name, _ in libraries]) + filename = f"{operator}_inference_time_comparison.svg" + # else: + # print("Unsupported plot type.") + # sys.exit(1) + + ############################## + # Prepare footer texts + footer_title = f'[{operator}] kernel relative inference time comparison' + default_config = config.get("base_configuration", {}) + + # Wrap the default configuration text to a given width. + wrapped_config = textwrap.wrap(f'Base configuration: {default_config}', width=160) + n_lines = len(wrapped_config) + config_text = "\n".join(wrapped_config) + + # Adjust the figure layout to provide extra space at the bottom. + if len(test_parameters) == 1: + plt.subplots_adjust(bottom=0.2+0.02*n_lines) + else: + plt.subplots_adjust(bottom=0.14+0.02*n_lines) + + # Add the footer title (bottom center) with fontsize 16. + fig.text(0.5, 0.035+n_lines*0.025, footer_title, ha='center', va='bottom', fontsize=18) + + # Add the default configuration text just below the title with the computed fontsize. 
+    fig.text(0.5, 0.02, config_text, ha='center', va='bottom', fontsize=12)
+
+    ############################
+    # save
+    plt.savefig(filename)
+    print(f"Plot saved as {filename}")
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab

From 2432868b646efab56819a117fbed94244c469456 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 3 Apr 2025 09:39:44 +0000
Subject: [PATCH 2/5] add: main export scripts for timing measurements and
 output display

---
 aidge_core/export_utils/__init__.py       |   2 +-
 aidge_core/export_utils/generate_main.py  | 124 ++++++++++++++++++
 .../main_benchmark_display_output.jinja   |  45 +++++++
 .../main_benchmark_inference_time.jinja   |  51 +++++++
 4 files changed, 221 insertions(+), 1 deletion(-)
 create mode 100644 aidge_core/export_utils/templates/main_benchmark_display_output.jinja
 create mode 100644 aidge_core/export_utils/templates/main_benchmark_inference_time.jinja

diff --git a/aidge_core/export_utils/__init__.py b/aidge_core/export_utils/__init__.py
index 72472eee0..87a43f6a6 100644
--- a/aidge_core/export_utils/__init__.py
+++ b/aidge_core/export_utils/__init__.py
@@ -3,4 +3,4 @@ from .code_generation import generate_file, generate_str, copy_file
 from .export_registry import ExportLib
 from .scheduler_export import scheduler_export
 from .tensor_export import tensor_to_c, generate_input_file
-from .generate_main import generate_main_cpp, generate_main_compare_cpp
+from .generate_main import generate_main_cpp, generate_main_compare_cpp, generate_main_inference_time_cpp, generate_main_display_output_cpp
diff --git a/aidge_core/export_utils/generate_main.py b/aidge_core/export_utils/generate_main.py
index 57fc68bca..288f4926d 100644
--- a/aidge_core/export_utils/generate_main.py
+++ b/aidge_core/export_utils/generate_main.py
@@ -129,3 +129,127 @@ def generate_main_compare_cpp(export_folder: str, graph_view: aidge_core.GraphVi
         outputs_dtype=outputs_dtype,
         outputs_size=outputs_size
     )
+
+def generate_main_inference_time_cpp(export_folder: str, graph_view: aidge_core.GraphView, nb_iterations, nb_warmup, inputs_tensor=None) -> None:
+    """
+    Generate a C++ file to manage the forward pass of a model using the given graph structure.
+
+    This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types,
+    and tensor sizes. It uses this data to populate a C++ file template (`main_benchmark_inference_time.jinja`), creating a file (`main.cpp`)
+    that calls the `model_forward` function, which handles data flow and processing for the exported model.
+
+    This function also generates files containing the input tensors if they have been set.
+
+    :param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved.
+    :type export_folder: str
+    :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and
+        ordered input/output data within the computational graph.
+    :type graph_view: aidge_core.graph_view
+    :param inputs_tensor: **For future use**: argument to provide tensors for the main function; not implemented yet!
+    :type inputs_tensor: None
+    :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes),
+        indicating an internal bug in the graph representation.
+ """ + outputs_name: list[str] = [] + outputs_dtype: list[str] = [] + outputs_size: list[int] = [] + inputs_name: list[str] = [] + gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs() + gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs() + + for in_node, in_idx in gv_inputs: + in_node_input, in_node_input_idx = in_node.input(in_idx) + in_name = f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}" + inputs_name.append(in_name) + input_tensor = in_node.get_operator().get_input(in_idx) + if input_tensor is None or input_tensor.undefined() or not input_tensor.has_impl(): + if inputs_tensor is not None: + aidge_core.Log.notice("No support for inputs_tensor argument yet.") + aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.") + else: + aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.") + else: + aidge_core.export_utils.generate_input_file(export_folder=export_folder, array_name=in_name, tensor=input_tensor) + + for out_node, out_id in gv_outputs: + outputs_name.append(f"{out_node.name()}_output_{out_id}") + out_tensor = out_node.get_operator().get_output(out_id) + outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype())) + outputs_size.append(out_tensor.size()) + + if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): + raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.") + + ROOT = Path(__file__).resolve().parents[0] + generate_file( + str(Path(export_folder) / "main.cpp"), + str(ROOT / "templates" / "main_benchmark_inference_time.jinja"), + func_name="model_forward", + inputs_name=inputs_name, + outputs_name=outputs_name, + outputs_dtype=outputs_dtype, + outputs_size=outputs_size, + nb_iterations=nb_iterations, + nb_warmup=nb_warmup + ) + +def generate_main_display_output_cpp(export_folder: str, graph_view: aidge_core.GraphView, inputs_tensor=None) -> None: + """ + Generate a C++ file to manage the forward pass of a model using the given graph structure. + + This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types, + and tensor sizes. It uses this data to populate a C++ file template (`main.jinja`), creating a file (`main.cpp`) + that call the `model_forward` function, which handles data flow and processing for the exported model. + + This function also generate files containing input tensor if they have been set. + + :param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved. + :type export_folder: str + :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and + ordered input/output data within the computational graph. + :type graph_view: aidge_core.graph_view + :param inputs_tensor: **For future** argument to provide tensor to use in the main function, not implemented yet! + :type inputs_tensor: None + :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes), + indicating an internal bug in the graph representation. 
+ """ + outputs_name: list[str] = [] + outputs_dtype: list[str] = [] + outputs_size: list[int] = [] + inputs_name: list[str] = [] + gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs() + gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs() + + for in_node, in_idx in gv_inputs: + in_node_input, in_node_input_idx = in_node.input(in_idx) + in_name = f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}" + inputs_name.append(in_name) + input_tensor = in_node.get_operator().get_input(in_idx) + if input_tensor is None or input_tensor.undefined() or not input_tensor.has_impl(): + if inputs_tensor is not None: + aidge_core.Log.notice("No support for inputs_tensor argument yet.") + aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.") + else: + aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.") + else: + aidge_core.export_utils.generate_input_file(export_folder=export_folder, array_name=in_name, tensor=input_tensor) + + for out_node, out_id in gv_outputs: + outputs_name.append(f"{out_node.name()}_output_{out_id}") + out_tensor = out_node.get_operator().get_output(out_id) + outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype())) + outputs_size.append(out_tensor.size()) + + if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size): + raise RuntimeError("FATAL: Output args list does not have the same length this is an internal bug.") + + ROOT = Path(__file__).resolve().parents[0] + generate_file( + str(Path(export_folder) / "main.cpp"), + str(ROOT / "templates" / "main_benchmark_display_output.jinja"), + func_name="model_forward", + inputs_name=inputs_name, + outputs_name=outputs_name, + outputs_dtype=outputs_dtype, + outputs_size=outputs_size + ) \ No newline at end of file diff --git a/aidge_core/export_utils/templates/main_benchmark_display_output.jinja b/aidge_core/export_utils/templates/main_benchmark_display_output.jinja new file mode 100644 index 000000000..967a28063 --- /dev/null +++ b/aidge_core/export_utils/templates/main_benchmark_display_output.jinja @@ -0,0 +1,45 @@ +#include <cstddef> +#include <cstdio> + +#include "forward.hpp" +{% for name in inputs_name %} +#include "{{ name }}.h" +{% endfor %} + +{% set printf_formats = { + "double": "lf", + "float": "f", + "int8_t": "hhd", + "int16_t": "hd", + "int32_t": "d", + "int64_t": "lld", + "uint8_t": "hhu", + "uint16_t": "hu", + "uint32_t": "u", + "uint64_t": "llu" +} %} + +int main() +{ + // Initialize the output arrays + {%- for o in range(outputs_name | length) %} + {{ outputs_dtype[o] }}* {{ outputs_name[o] }} = nullptr; + // {{ outputs_dtype[o] }}* results_{{ o }} = new {{ outputs_dtype[o] }}[{{ outputs_size[o] }}]; + {% endfor %} + + // Call the forward function + {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }}); + + // Print the results of each output + {%- for o in range(outputs_name | length) %} + for (std::size_t i = 0; i < {{ outputs_size[o] }}; ++i) { + {%- if outputs_dtype[o] in ["double", "float"] %} + std::printf("%.10{{ printf_formats[outputs_dtype[o]] }} ", {{ outputs_name[o] }}[i]); + {%- else %} + std::printf("%{{ printf_formats[outputs_dtype[o]] }} ", {{ outputs_name[o] }}[i]); + {%- endif %} + } + std::printf("\n"); + {% endfor %} + return 0; +} diff --git 
diff --git a/aidge_core/export_utils/templates/main_benchmark_inference_time.jinja b/aidge_core/export_utils/templates/main_benchmark_inference_time.jinja
new file mode 100644
index 000000000..0943265e1
--- /dev/null
+++ b/aidge_core/export_utils/templates/main_benchmark_inference_time.jinja
@@ -0,0 +1,51 @@
+#include <cstddef>
+#include <cstdio>
+#include <ctime>
+
+#include "forward.hpp"
+{% for name in inputs_name %}
+#include "{{ name }}.h"
+{% endfor %}
+
+{% set printf_formats = {
+    "double": "%lf",
+    "float": "%f",
+    "int8_t": "%hhd",
+    "int16_t": "%hd",
+    "int32_t": "%d",
+    "int64_t": "%lld",
+    "uint8_t": "%hhu",
+    "uint16_t": "%hu",
+    "uint32_t": "%u",
+    "uint64_t": "%llu"
+} %}
+
+int main()
+{
+    // Initialize the output arrays
+    {%- for o in range(outputs_name | length) %}
+    {{ outputs_dtype[o] }}* {{ outputs_name[o] }} = nullptr;
+    {% endfor %}
+    clock_t start;
+    clock_t end;
+    double times[{{ nb_iterations }}] = {0};
+    for (std::size_t i = 0; i < {{ nb_iterations }} + {{ nb_warmup }}; ++i) {
+        if (i < {{ nb_warmup }}) {
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+        } else {
+            start = clock();
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            end = clock();
+            times[i - {{ nb_warmup }}] = ((double)(end - start)/CLOCKS_PER_SEC)/4.0;
+        }
+    }
+
+    for (std::size_t i = 0; i < {{ nb_iterations }}; ++i) {
+        printf("%.10lf ", times[i]);
+    }
+    printf("\n");
+    return 0;
+}
-- 
GitLab

From cd37e0894a11132bf0f957c48b8b4eaa8923010a Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 3 Apr 2025 09:52:24 +0000
Subject: [PATCH 3/5] update main benchmark script arguments, add warnings and
 comments about limitations and increase results comparison margins

---
 benchmark/benchmark.py       | 53 ++++++++++++++++++++++++++----------
 benchmark/benchmark_torch.py |  7 ++---
 2 files changed, 42 insertions(+), 18 deletions(-)

diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py
index 6726be68a..1061d51f5 100644
--- a/benchmark/benchmark.py
+++ b/benchmark/benchmark.py
@@ -8,10 +8,11 @@ The configuration is provided via a JSON file.
 
 import argparse
 import copy
-import importlib
+from importlib import import_module
 import json
 import os
 import sys
+from typing import Any
 
 import numpy as np
 import onnx
@@ -31,7 +32,7 @@
     Exits if the module is not installed.
     """
     try:
-        module = importlib.import_module(module_name)
+        module = import_module(module_name)
         print(f"'{module_name}' module successfully imported")
         return module
    except ImportError:
@@ -73,8 +74,8 @@ def update_test_config(
     try:
         extra_attrs = other_parameters[param][str(value)]["attributes"]
     except KeyError:
-        print(
-            f"'{param}': '{value}': 'attributes' - Key not found in other_parameters. Config file may be ill-formed."
+        ai.Log.error(
+            f"Test configuration {{'{param}': '{value}'}} has no 'attributes' property. Config file may be ill-formed."
         )
         return None, None
@@ -82,8 +83,8 @@
     try:
         extra_input_shapes = other_parameters[param][str(value)]["input_shapes"]
     except KeyError:
-        print(
-            f"'{param}': '{value}': 'input_shapes' - Key not found in other_parameters. Config file may be ill-formed."
+        ai.Log.error(
+            f"Test configuration {{'{param}': '{value}'}} has no 'input_shapes' property. Config file may be ill-formed."
         )
         return None, None
@@ -152,7 +153,8 @@ def compute_output(
         return benchmark_torch.compute_output(model, {v[0]: v[1] for v in input_data})
     else:
-        model = aidge_onnx.load(model=model) if "aidge" in module_name else model
+        if "aidge" in module_name:
+            model = aidge_onnx.load(model=model)
         return inference_module.benchmark.compute_output(model, input_data)
@@ -199,23 +201,29 @@
         "--time", "-t", action="store_true", help="Compute inference time"
     )
     parser.add_argument(
-        "--save-directory",
+        "--results-directory",
         type=str,
         required=True,
         help="Directory to save the results",
     )
+    parser.add_argument(
+        "--results-filename",
+        type=str,
+        required=False,
+        help="Name of the saved results file. If not provided, it defaults to '<operator_name>_<module_to_bench>.json'. If a file with that name already exists at that location, it is overwritten, with entries replaced individually as new results are computed."
+    )
     args = parser.parse_args()
 
     compare_mode = args.compare_with_onnxruntime
     time_mode = args.time
     module_name = args.module_to_bench
-    save_directory = args.save_directory
+    results_directory = args.results_directory
 
     # Load the inference module
     inference_module = load_inference_module(module_name)
 
     # Configure aidge logging
-    ai.Log.set_console_level(ai.Level.Error)
+    ai.Log.set_console_level(ai.Level.Warn)
     ai.Log.set_precision(10)
 
     # Load configuration
@@ -224,6 +232,14 @@
     operator_name: str = config["operator"]
     opset_version: int = config["opset_version"]
     initializer_rank: int = config.get("initializer_rank", 1)
 
+    test_meta_data: dict[str, Any] = config["test_meta_data"]
+    if test_meta_data["multiple_batchs"] and "export" in module_name:
+        ai.Log.warn("The tested module seems to be an export module and your test cases contain "
+                    "\033[31;1mmultiple\033[0m batch inputs. This could lead to inaccurate results due to "
+                    "the stream-based (single-batch) nature of export implementations, or to an error during "
+                    "the 'export generation' step. Unless you know what you are doing, you should "
+                    "probably change your configuration file to single-batch tests.")
+
     base_input_shapes: list[tuple[str, list[int]]] = config["base_configuration"][
         "input_shapes"
     ]
@@ -252,8 +268,8 @@
 
     # Initialize or load existing benchmark results
     results = {"library": "", "compare": {}, "time": {}}
-    filename: str = f"{operator_aidge.lower()}_{module_name}.json"
-    results_file_path = os.path.join(save_directory, filename)
+    filename: str = (args.results_filename + ".json") if args.results_filename else f"{operator_aidge.lower()}_{module_name}.json"
+    results_file_path = os.path.join(results_directory, filename)
     # results_file_path = get_results_file_path(module_name, operator_aidge, save_directory)
     if os.path.exists(results_file_path):
         with open(results_file_path, "r") as f:
@@ -270,6 +286,13 @@
 
         for value in test_values:
             print(f"▷ {param} -- {value}")
+            try:
+                other_parameters[param][str(value)]
+            except KeyError:
+                ai.Log.error(
+                    f"Test configuration {{'{param}': '{value}'}} not found. Config file may be ill-formed."
+ ) + continue updated_attrs, updated_input_shapes = update_test_config( param, value, @@ -304,12 +327,14 @@ def main(): print(f" └┬─Assessing results are the same as 'onnxruntime' module...") ref = compute_output("onnxruntime", model, input_data, None) tested = compute_output(module_name, model, input_data, inference_module) - + ai.Log.info(f"ref: {ref}\n") + ai.Log.info(f"tested: {tested}\n") if len(ref) > 1: print("Multi-output comparison not handled yet") + print([i.shape for i in ref]) sys.exit(1) results["compare"][param][str(value)] = bool( - np.all(np.isclose(ref, tested)) + np.all(np.isclose(ref, tested, rtol=1e-3, atol=1e-5)) ) print( f" └─[ {'o' if results['compare'][param][str(value)] else 'x'} ]" diff --git a/benchmark/benchmark_torch.py b/benchmark/benchmark_torch.py index 0275daa01..abe92ffe2 100644 --- a/benchmark/benchmark_torch.py +++ b/benchmark/benchmark_torch.py @@ -26,7 +26,6 @@ def measure_inference_time(model_onnx: ModelProto, input_data: dict[str, np.ndar inputs = [torch.tensor(v, device=device) for _, v in input_data.items()] timings = [] - with torch.no_grad(): # Warm-up runs for i in range(nb_warmup + nb_iterations): @@ -59,7 +58,7 @@ def compute_output(model_onnx: ModelProto, input_data: dict[str, np.ndarray]) -> inputs = [torch.tensor(v, device=device) for _, v in input_data.items()] with torch.no_grad(): - outputs = model(*inputs) - outputs = [o.numpy() if isinstance(o, torch.Tensor) else np.array(o) for o in outputs] + # Warning: not tested for multiple outputs case + output = model(*inputs) - return outputs + return output.numpy() -- GitLab From f537a2f8ec5b48564d3bed4e30f29ef32733d15c Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 3 Apr 2025 12:04:53 +0000 Subject: [PATCH 4/5] add: test config files for some operators (element wise, fc, conv, concat) --- benchmark/operator_config/add_config.json | 160 +++++++++++++++++++ benchmark/operator_config/concat_config.json | 86 ++++++++++ benchmark/operator_config/config.json | 64 ++++++++ benchmark/operator_config/conv2d_config.json | 112 +++++++++++++ benchmark/operator_config/div_onfig.json | 160 +++++++++++++++++++ benchmark/operator_config/fc_config.json | 49 ++++++ benchmark/operator_config/mul_config.json | 160 +++++++++++++++++++ benchmark/operator_config/relu_config.json | 59 +++++++ benchmark/operator_config/sub_config.json | 160 +++++++++++++++++++ 9 files changed, 1010 insertions(+) create mode 100644 benchmark/operator_config/add_config.json create mode 100644 benchmark/operator_config/concat_config.json create mode 100644 benchmark/operator_config/config.json create mode 100644 benchmark/operator_config/conv2d_config.json create mode 100644 benchmark/operator_config/div_onfig.json create mode 100644 benchmark/operator_config/fc_config.json create mode 100644 benchmark/operator_config/mul_config.json create mode 100644 benchmark/operator_config/relu_config.json create mode 100644 benchmark/operator_config/sub_config.json diff --git a/benchmark/operator_config/add_config.json b/benchmark/operator_config/add_config.json new file mode 100644 index 000000000..b0f9424d3 --- /dev/null +++ b/benchmark/operator_config/add_config.json @@ -0,0 +1,160 @@ +{ + "operator": "Add", + "opset_version": 21, + "initializer_rank": 2, + "test_meta_data": { + "multiple_batchs": true + }, + "base_configuration": { + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ], + "attributes": {} + }, + "test_configuration": { + "main_parameters": { + "dim size": [ + 
1,4,16,64,128 + ], + "one dim broadcasted (idx)": [ + 0, 1, 2, 3 + ], + "two dims broadcasted (idx)": [ + [0,1], [0,2], [0,3], [1,2], [1,3], [2,3] + ], + "nb missing axis 1st input": [ + 1, 2, 3, 4 + ] + }, + "other_parameters": { + "dim size": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 1, 1]], + ["input_1", [1, 1, 1, 1]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", [4, 4, 4, 4]], + ["input_1", [4, 4, 4, 4]] + ] + }, + "16": { + "attributes": {}, + "input_shapes": [ + ["input_0", [16, 16, 16, 16]], + ["input_1", [16, 16, 16, 16]] + ] + }, + "64": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ] + }, + "128": { + "attributes": {}, + "input_shapes": [ + ["input_0", [128, 128, 128, 128]], + ["input_1", [128, 128, 128, 128]] + ] + } + }, + "one dim broadcasted (idx)": { + "0": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 64]] + ] + }, + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 1]] + ] + } + }, + "two dims broadcasted (idx)": { + "[0, 1]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 64, 64]] + ] + }, + "[0, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 1, 64]] + ] + }, + "[0, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 1]] + ] + }, + "[1, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 1, 64]] + ] + }, + "[1, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 1]] + ] + }, + "[2, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 1]] + ] + } + }, + "nb missing axis 1st input": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", []] + ] + } + } + } + } +} \ No newline at end of file diff --git a/benchmark/operator_config/concat_config.json b/benchmark/operator_config/concat_config.json new file mode 100644 index 000000000..344d77a7c --- /dev/null +++ b/benchmark/operator_config/concat_config.json @@ -0,0 +1,86 @@ +{ + "operator": "Concat", + "opset_version": 13, + "initializer_rank": 3, + "test_meta_data": { + "multiple_batchs": false + }, + "base_configuration": { + "input_shapes": [ + ["input_0", [1, 64, 64, 64]], + ["input_1", [1, 64, 64, 64]], + ["input_2", [1, 64, 64, 64]] + ], + "attributes": { + "axis": 1 + } + }, + "test_configuration": { + "main_parameters": { + "axis": [ + 0,1,2,3 + ], + "dims size": [ + 1, 4, 16, 64, 128 + ] + }, + "other_parameters": { + "axis": { + "0": { + "attributes": {}, + "input_shapes": [] + }, + "1": { + "attributes": {}, + "input_shapes": [] + }, + "2": { + "attributes": {}, + "input_shapes": [] + }, + "3": { + "attributes": {}, + "input_shapes": [] + } + }, + "dims size": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 1, 1]], + ["input_1", [1, 1, 1, 1]], + ["input_2", [1, 1, 1, 1]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 4, 4, 4]], + ["input_1", [1, 4, 4, 4]], + ["input_2", [1, 4, 4, 4]] + ] + }, + "16": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 
16, 16, 16]],
+                        ["input_1", [1, 16, 16, 16]],
+                        ["input_2", [1, 16, 16, 16]]
+                    ]
+                },
+                "64": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+                "128": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_0", [1, 128, 128, 128]],
+                        ["input_1", [1, 128, 128, 128]],
+                        ["input_2", [1, 128, 128, 128]]
+                    ]
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/benchmark/operator_config/config.json b/benchmark/operator_config/config.json
new file mode 100644
index 000000000..cd471787e
--- /dev/null
+++ b/benchmark/operator_config/config.json
@@ -0,0 +1,64 @@
+{
+    "operator": "onnx_operator_type",
+    "opset_version": 21,
+    "initializer_rank": 2,
+    "test_meta_data": {
+        "multiple_batchs": false
+    },
+    "base_configuration": {
+        "input_shapes": [
+            ["input_0", [1, 10, 10, 10]],
+            ["input_1", [1, 10, 10, 10]]
+        ],
+        "attributes": {
+            "attributes_name_0": 0,
+            "attributes_name_1": 0
+        }
+    },
+    "test_configuration": {
+        "main_parameters": {
+            "tested_param_0": [
+                1, 2, 3
+            ],
+            "tested_param_1": [
+                [1, 1],
+                [3, 3],
+                [5, 5]
+            ]
+        },
+        "other_parameters": {
+            "tested_param_0": {
+                "1": {
+                    "attributes": {
+                        "attributes_name_0": 1
+                    },
+                    "input_shapes": [
+                        ["input_0", [1, 10, 10, 10]]
+                    ]
+                },
+                "2": {},
+                "3": {}
+            },
+            "tested_param_1": {
+                "[1, 1]": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_1", [1, 10, 1, 1]]
+                    ]
+                },
+                "[3, 3]": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_1", [1, 10, 3, 3]]
+                    ]
+                },
+                "[5, 5]": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_1", [1, 10, 5, 5]]
+                    ]
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/benchmark/operator_config/conv2d_config.json b/benchmark/operator_config/conv2d_config.json
new file mode 100644
index 000000000..bf79dc90e
--- /dev/null
+++ b/benchmark/operator_config/conv2d_config.json
@@ -0,0 +1,112 @@
+{
+    "operator": "Conv",
+    "opset_version": 21,
+    "initializer_rank": 1,
+    "test_meta_data": {
+        "multiple_batchs": false
+    },
+    "base_configuration": {
+        "input_shapes": [
+            ["input_0", [1, 10, 200, 200]],
+            ["weight_1", [10, 10, 3, 3]],
+            ["bias_2", [10]]
+        ],
+        "attributes": {
+            "kernel_shape": [3, 3],
+            "strides": [1, 1],
+            "dilations": [1, 1]
+        }
+    },
+    "test_configuration": {
+        "main_parameters": {
+            "feature_map_size": [
+                10,100,500
+            ],
+            "kernel_shape": [
+                [1, 1],
+                [3, 3],
+                [5, 5]
+            ],
+            "strides": [
+                [1, 1],
+                [2, 2],
+                [3, 3]
+            ],
+            "dilations": [
+                [1, 1],
+                [2, 2],
+                [3, 3]
+            ]
+        },
+        "other_parameters": {
+            "feature_map_size": {
+                "10": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_0", [1, 10, 10, 10]]
+                    ]
+                },
+                "100": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_0", [1, 10, 100, 100]]
+                    ]
+                },
+                "500": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_0", [1, 10, 500, 500]]
+                    ]
+                }
+            },
+            "kernel_shape": {
+                "[1, 1]": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["weight_1", [10, 10, 1, 1]]
+                    ]
+                },
+                "[3, 3]": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["weight_1", [10, 10, 3, 3]]
+                    ]
+                },
+                "[5, 5]": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["weight_1", [10, 10, 5, 5]]
+                    ]
+                }
+            },
+            "strides": {
+                "[1, 1]": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+                "[2, 2]": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+                "[3, 3]": {
+                    "attributes": {},
+                    "input_shapes": []
+                }
+            },
+            "dilations": {
+                "[1, 1]": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+                "[2, 2]": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+                "[3, 3]": {
+                    "attributes": {},
+                    "input_shapes": []
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/benchmark/operator_config/div_onfig.json
b/benchmark/operator_config/div_onfig.json new file mode 100644 index 000000000..446715ad6 --- /dev/null +++ b/benchmark/operator_config/div_onfig.json @@ -0,0 +1,160 @@ +{ + "operator": "Div", + "opset_version": 21, + "initializer_rank": 2, + "test_meta_data": { + "multiple_batchs": true + }, + "base_configuration": { + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ], + "attributes": {} + }, + "test_configuration": { + "main_parameters": { + "dim size": [ + 1,4,16,64,128 + ], + "one dim broadcasted (idx)": [ + 0, 1, 2, 3 + ], + "two dims broadcasted (idx)": [ + [0,1], [0,2], [0,3], [1,2], [1,3], [2,3] + ], + "nb missing axis 1st input": [ + 1, 2, 3, 4 + ] + }, + "other_parameters": { + "dim size": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 1, 1]], + ["input_1", [1, 1, 1, 1]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", [4, 4, 4, 4]], + ["input_1", [4, 4, 4, 4]] + ] + }, + "16": { + "attributes": {}, + "input_shapes": [ + ["input_0", [16, 16, 16, 16]], + ["input_1", [16, 16, 16, 16]] + ] + }, + "64": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ] + }, + "128": { + "attributes": {}, + "input_shapes": [ + ["input_0", [128, 128, 128, 128]], + ["input_1", [128, 128, 128, 128]] + ] + } + }, + "one dim broadcasted (idx)": { + "0": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 64]] + ] + }, + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 1]] + ] + } + }, + "two dims broadcasted (idx)": { + "[0, 1]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 64, 64]] + ] + }, + "[0, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 1, 64]] + ] + }, + "[0, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 1]] + ] + }, + "[1, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 1, 64]] + ] + }, + "[1, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 1]] + ] + }, + "[2, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 1]] + ] + } + }, + "nb missing axis 1st input": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", []] + ] + } + } + } + } +} \ No newline at end of file diff --git a/benchmark/operator_config/fc_config.json b/benchmark/operator_config/fc_config.json new file mode 100644 index 000000000..6a401c84a --- /dev/null +++ b/benchmark/operator_config/fc_config.json @@ -0,0 +1,49 @@ +{ + "operator": "Gemm", + "opset_version": 21, + "initializer_rank": 1, + "test_meta_data": { + "multiple_batchs": false + }, + "base_configuration": { + "input_shapes": [ + ["input_0", [1, 100]], + ["weight_1", [100, 50]], + ["bias_2", [50]] + ], + "attributes": { + } + }, + "test_configuration": { + "main_parameters": { + "input_size_0": [ + 10, 100, 1000 + ] + }, + "other_parameters": { + "input_size_0": { + "10": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 10]], + ["weight_1", [10, 50]] + ] + }, + "100": { + "attributes": {}, + "input_shapes": [ + 
["input_0", [1, 100]], + ["weight_1", [100, 50]] + ] + }, + "1000": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1000]], + ["weight_1", [1000, 50]] + ] + } + } + } + } +} \ No newline at end of file diff --git a/benchmark/operator_config/mul_config.json b/benchmark/operator_config/mul_config.json new file mode 100644 index 000000000..3c89d5671 --- /dev/null +++ b/benchmark/operator_config/mul_config.json @@ -0,0 +1,160 @@ +{ + "operator": "Mul", + "opset_version": 21, + "initializer_rank": 2, + "test_meta_data": { + "multiple_batchs": true + }, + "base_configuration": { + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ], + "attributes": {} + }, + "test_configuration": { + "main_parameters": { + "dim size": [ + 1,4,16,64,128 + ], + "one dim broadcasted (idx)": [ + 0, 1, 2, 3 + ], + "two dims broadcasted (idx)": [ + [0,1], [0,2], [0,3], [1,2], [1,3], [2,3] + ], + "nb missing axis 1st input": [ + 1, 2, 3, 4 + ] + }, + "other_parameters": { + "dim size": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 1, 1]], + ["input_1", [1, 1, 1, 1]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", [4, 4, 4, 4]], + ["input_1", [4, 4, 4, 4]] + ] + }, + "16": { + "attributes": {}, + "input_shapes": [ + ["input_0", [16, 16, 16, 16]], + ["input_1", [16, 16, 16, 16]] + ] + }, + "64": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ] + }, + "128": { + "attributes": {}, + "input_shapes": [ + ["input_0", [128, 128, 128, 128]], + ["input_1", [128, 128, 128, 128]] + ] + } + }, + "one dim broadcasted (idx)": { + "0": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 64]] + ] + }, + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 1]] + ] + } + }, + "two dims broadcasted (idx)": { + "[0, 1]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 64, 64]] + ] + }, + "[0, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 1, 64]] + ] + }, + "[0, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 1]] + ] + }, + "[1, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 1, 64]] + ] + }, + "[1, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 1]] + ] + }, + "[2, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 1]] + ] + } + }, + "nb missing axis 1st input": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", []] + ] + } + } + } + } +} \ No newline at end of file diff --git a/benchmark/operator_config/relu_config.json b/benchmark/operator_config/relu_config.json new file mode 100644 index 000000000..f97b85f5f --- /dev/null +++ b/benchmark/operator_config/relu_config.json @@ -0,0 +1,59 @@ +{ + "operator": "Relu", + "opset_version": 14, + "initializer_rank": 1, + "test_meta_data": { + "multiple_batchs": true + }, + "base_configuration": { + "input_shapes": [ + ["input_0", [64, 64, 64, 64]] + ], + "attributes": {} + }, + "test_configuration": { + "main_parameters": { + "dims size": [ + 1, 
4, 16, 32, 64, 128 + ] + }, + "other_parameters": { + "dims size": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 1, 1]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", [4, 4, 4, 4]] + ] + }, + "16": { + "attributes": {}, + "input_shapes": [ + ["input_0", [16, 16, 16, 16]] + ] + }, + "32": { + "attributes": {}, + "input_shapes": [ + ["input_0", [32, 32, 32, 32]] + ] + }, + "64": { + "attributes": {}, + "input_shapes": [] + }, + "128": { + "attributes": {}, + "input_shapes": [ + ["input_0", [128, 128, 128, 128]] + ] + } + } + } + } +} \ No newline at end of file diff --git a/benchmark/operator_config/sub_config.json b/benchmark/operator_config/sub_config.json new file mode 100644 index 000000000..17359793a --- /dev/null +++ b/benchmark/operator_config/sub_config.json @@ -0,0 +1,160 @@ +{ + "operator": "Sub", + "opset_version": 21, + "initializer_rank": 2, + "test_meta_data": { + "multiple_batchs": true + }, + "base_configuration": { + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ], + "attributes": {} + }, + "test_configuration": { + "main_parameters": { + "dim size": [ + 1,4,16,64,128 + ], + "one dim broadcasted (idx)": [ + 0, 1, 2, 3 + ], + "two dims broadcasted (idx)": [ + [0,1], [0,2], [0,3], [1,2], [1,3], [2,3] + ], + "nb missing axis 1st input": [ + 1, 2, 3, 4 + ] + }, + "other_parameters": { + "dim size": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 1, 1]], + ["input_1", [1, 1, 1, 1]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", [4, 4, 4, 4]], + ["input_1", [4, 4, 4, 4]] + ] + }, + "16": { + "attributes": {}, + "input_shapes": [ + ["input_0", [16, 16, 16, 16]], + ["input_1", [16, 16, 16, 16]] + ] + }, + "64": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 64]], + ["input_1", [64, 64, 64, 64]] + ] + }, + "128": { + "attributes": {}, + "input_shapes": [ + ["input_0", [128, 128, 128, 128]], + ["input_1", [128, 128, 128, 128]] + ] + } + }, + "one dim broadcasted (idx)": { + "0": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 64]] + ] + }, + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64, 1]] + ] + } + }, + "two dims broadcasted (idx)": { + "[0, 1]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 1, 64, 64]] + ] + }, + "[0, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 1, 64]] + ] + }, + "[0, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [1, 64, 64, 1]] + ] + }, + "[1, 2]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 1, 64]] + ] + }, + "[1, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 1, 64, 1]] + ] + }, + "[2, 3]": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 1, 1]] + ] + } + }, + "nb missing axis 1st input": { + "1": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64, 64]] + ] + }, + "2": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64, 64]] + ] + }, + "3": { + "attributes": {}, + "input_shapes": [ + ["input_0", [64]] + ] + }, + "4": { + "attributes": {}, + "input_shapes": [ + ["input_0", []] + ] + } + } + } + } +} \ No newline at end of file -- GitLab From 3adc3fd7283689d05a36efa289f98236f2a13722 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: 
Thu, 3 Apr 2025 12:05:30 +0000 Subject: [PATCH 5/5] change plot color --- benchmark/generate_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/generate_graph.py b/benchmark/generate_graph.py index 720755225..82f61d5c8 100644 --- a/benchmark/generate_graph.py +++ b/benchmark/generate_graph.py @@ -132,7 +132,7 @@ def create_relative_difference_plots(test_parameters: dict, ref_times: dict, ref x=f'{param_name}', y='Ratio', hue='Library', - palette='viridis', + palette='hls', ax=ax, errorbar=None ) -- GitLab
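
Usage sketch: the flow set up by the patches above is (1) run benchmark/benchmark.py once per
module to produce a per-operator JSON timing file, then (2) feed those files to
benchmark/generate_graph.py, which plots the median-time ratio of each library against a
reference. The commands below are illustrative assumptions, not taken from the patches: the
backend module name `aidge_backend_cpu`, the `results/` directory, and the derived file name
`add_onnxruntime.json` (the script builds it as `<operator>_<module>.json`) may differ in your
setup. Any module exposing `benchmark.measure_inference_time` and `benchmark.compute_output`
should work; `onnxruntime` and `torch` are special-cased. Note that `--results-directory`
replaces `--save-directory` as of PATCH 3/5.

  # Reference timings with ONNXRuntime
  python benchmark/benchmark.py --config-file benchmark/operator_config/add_config.json \
      --module-to-bench onnxruntime --time --results-directory results

  # Candidate module timings plus a numerical check against ONNXRuntime
  python benchmark/benchmark.py --config-file benchmark/operator_config/add_config.json \
      --module-to-bench aidge_backend_cpu --time --compare-with-onnxruntime --results-directory results

  # Bar plot of median inference-time ratios (candidate / reference)
  python benchmark/generate_graph.py --operator-config benchmark/operator_config/add_config.json \
      --ref results/add_onnxruntime.json --libs results/add_aidge_backend_cpu.json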