From 5d9072be15594fbeec1e5a58854ad4662e136c32 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Fri, 7 Mar 2025 01:52:04 +0000
Subject: [PATCH 1/5] [Add][WIP] benchmark scripts

- main performance measurement script
- dedicated functions for benchmarking the onnxruntime and torch libraries
- comparison graph generation script
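
Example invocation (illustrative paths; flags as defined in benchmark/benchmark.py):

    python benchmark/benchmark.py --config-file operator_config/conv2d_config.json \
        --module-to-bench onnxruntime --time --save-directory results/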
---
 benchmark/benchmark.py             | 326 +++++++++++++++++++++++++++++
 benchmark/benchmark_onnxruntime.py |  39 ++++
 benchmark/benchmark_torch.py       |  65 ++++++
 benchmark/generate_graph.py        | 256 ++++++++++++++++++++++
 4 files changed, 686 insertions(+)
 create mode 100644 benchmark/benchmark.py
 create mode 100644 benchmark/benchmark_onnxruntime.py
 create mode 100644 benchmark/benchmark_torch.py
 create mode 100644 benchmark/generate_graph.py

diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py
new file mode 100644
index 000000000..6726be68a
--- /dev/null
+++ b/benchmark/benchmark.py
@@ -0,0 +1,326 @@
+"""
+Operator Kernel Performance Benchmarking Script
+
+This script benchmarks operator kernels using a specified inference module.
+It supports timing measurements and (optionally) output comparisons with ONNXRuntime.
+The configuration is provided via a JSON file.
+"""
+
+import argparse
+import copy
+import importlib
+import json
+import os
+import sys
+
+import numpy as np
+import onnx
+
+import aidge_core as ai
+import aidge_onnx
+from aidge_onnx.generate_singleop_onnx import create_onnx_model
+
+# Benchmark parameters
+NB_WARMUPS: int = 10
+NB_ITERATIONS: int = 50
+
+
+def load_inference_module(module_name: str):
+    """
+    Dynamically imports and returns the inference module.
+    Exits if the module is not installed.
+    """
+    try:
+        module = importlib.import_module(module_name)
+        print(f"'{module_name}' module successfully imported")
+        return module
+    except ImportError:
+        print(
+            f"Error: {module_name} is not installed. Install it using 'pip install {module_name}'."
+        )
+        sys.exit(1)
+
+
+def load_config(config_file: str) -> dict:
+    """
+    Loads and returns the JSON configuration from the given file.
+    """
+    with open(config_file, "r") as f:
+        return json.load(f)
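+
+# For reference, a sketch of the top-level keys consumed below (mirroring how
+# main() reads the file):
+#   "operator", "opset_version", "initializer_rank",
+#   "base_configuration": {"input_shapes", "attributes"},
+#   "test_configuration": {"main_parameters", "other_parameters"}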
+
+
+def update_test_config(
+    param: str,
+    value,
+    base_attributes: dict,
+    base_input_shapes: list,
+    other_parameters: dict,
+    operator_attributes: list,
+):
+    """
+    Updates the operator attributes and input shapes based on the test parameter.
+
+    Returns:
+        tuple: (updated_attributes, updated_input_shapes) or (None, None) if keys are missing.
+    """
+    attributes = copy.deepcopy(base_attributes)
+    input_shapes = copy.deepcopy(base_input_shapes)
+
+    # Update if the parameter is a valid operator attribute
+    if param in operator_attributes:
+        attributes[param] = value
+
+    try:
+        extra_attrs = other_parameters[param][str(value)]["attributes"]
+    except KeyError:
+        print(
+            f"'{param}': '{value}': 'attributes' - Key not found in other_parameters. Config file may be ill-formed."
+        )
+        return None, None
+    attributes.update(extra_attrs)
+
+    try:
+        extra_input_shapes = other_parameters[param][str(value)]["input_shapes"]
+    except KeyError:
+        print(
+            f"'{param}': '{value}': 'input_shapes' - Key not found in other_parameters. Config file may be ill-formed."
+        )
+        return None, None
+
+    for shape_update in extra_input_shapes:
+        name, new_shape = shape_update
+        for base_shape in input_shapes:
+            if base_shape[0] == name:
+                base_shape[1] = new_shape
+                break
+
+    return attributes, input_shapes
+
+
+# def get_results_file_path(module_name: str, operator_aidge: str, save_directory: str) -> str:
+#     """
+#     Constructs and returns the file path for saving the benchmark results.
+#     """
+#     if module_name == "onnxruntime":
+#         filename = f"{operator_aidge}_onnxruntime.json"
+#     elif module_name == "pytorch":
+#         filename = f"{operator_aidge}_pytorch.json"
+#     else:
+#         filename = f"{operator_aidge}.json"
+#     return os.path.join(save_directory, filename)
+
+
+def measure_inference_time(
+    module_name: str, model: onnx.ModelProto, input_data, inference_module=None
+) -> list[float]:
+    """
+    Measures inference time using the appropriate benchmark function.
+    """
+    if module_name == "onnxruntime":
+        import benchmark_onnxruntime
+
+        return benchmark_onnxruntime.measure_inference_time(
+            model, {v[0]: v[1] for v in input_data}, NB_WARMUPS, NB_ITERATIONS
+        )
+    elif module_name == "torch":
+        import benchmark_torch
+
+        return benchmark_torch.measure_inference_time(
+            model, {v[0]: v[1] for v in input_data}, NB_WARMUPS, NB_ITERATIONS
+        )
+    else:
+        model = aidge_onnx.load(model=model) if "aidge" in module_name else model
+        return inference_module.benchmark.measure_inference_time(
+            model, input_data, NB_WARMUPS, NB_ITERATIONS
+        )
+
+
+def compute_output(
+    module_name: str, model: onnx.ModelProto, input_data, inference_module
+) -> list[np.ndarray]:
+    """
+    Computes and returns the model outputs using the appropriate backend.
+    """
+    if module_name == "onnxruntime":
+        import benchmark_onnxruntime
+
+        return benchmark_onnxruntime.compute_output(
+            model, {v[0]: v[1] for v in input_data}
+        )
+    elif module_name == "torch":
+        import benchmark_torch
+
+        return benchmark_torch.compute_output(model, {v[0]: v[1] for v in input_data})
+    else:
+        model = aidge_onnx.load(model=model) if "aidge" in module_name else model
+        return inference_module.benchmark.compute_output(model, input_data)
+
+
+def prepare_input_data(
+    input_shapes: list[tuple[str, list[int]]], initializer_rank: int
+) -> list[tuple[str, np.ndarray]]:
+    """
+    Generates random input data for the first `initializer_rank` inputs.
+    """
+    data: list[tuple[str, np.ndarray]] = []
+    for i, conf in enumerate(input_shapes):
+        name, shape = conf
+        if i < initializer_rank:
+            random_array = np.array(np.random.rand(*shape)).astype(np.float32)
+            data.append((name, random_array))
+    return data
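+
+# Example (a sketch): with initializer_rank=1, only the first input receives
+# runtime data; the remaining inputs are assumed to be baked into the ONNX
+# model as initializers:
+#   prepare_input_data([["input_0", [1, 3]], ["weight", [3, 3]]], 1)
+#   -> [("input_0", float32 ndarray of shape (1, 3))]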
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Operator Kernel Performance Benchmarking"
+    )
+    parser.add_argument(
+        "--config-file",
+        "-cf",
+        type=str,
+        required=True,
+        help="Path to configuration JSON with operator type, attributes, and input sizes.",
+    )
+    parser.add_argument(
+        "--module-to-bench",
+        "-mtb",
+        type=str,
+        required=True,
+        help="Name of the module containing the inference functions",
+    )
+    parser.add_argument(
+        "--compare-with-onnxruntime",
+        "-cwo",
+        action="store_true",
+        help="Compare output with ONNXRuntime",
+    )
+    parser.add_argument(
+        "--time", "-t", action="store_true", help="Compute inference time"
+    )
+    parser.add_argument(
+        "--save-directory",
+        type=str,
+        required=True,
+        help="Directory to save the results",
+    )
+    args = parser.parse_args()
+
+    compare_mode = args.compare_with_onnxruntime
+    time_mode = args.time
+    module_name = args.module_to_bench
+    save_directory = args.save_directory
+
+    # Load the inference module
+    inference_module = load_inference_module(module_name)
+
+    # Configure aidge logging
+    ai.Log.set_console_level(ai.Level.Error)
+    ai.Log.set_precision(10)
+
+    # Load configuration
+    config = load_config(args.config_file)
+    operator_name: str = config["operator"]
+    opset_version: int = config["opset_version"]
+    initializer_rank: int = config.get("initializer_rank", 1)
+
+    base_input_shapes: list[tuple[str, list[int]]] = config["base_configuration"][
+        "input_shapes"
+    ]
+    base_attributes: dict = config["base_configuration"].get("attributes", {})
+
+    main_parameters: dict[str, Any] = config["test_configuration"].get(
+        "main_parameters", {}
+    )
+    other_parameters: dict[str, dict] = config["test_configuration"].get(
+        "other_parameters", {}
+    )
+
+    # Get operator attributes from the schema for filtering test parameters
+    operator_schema = onnx.defs.get_schema(operator_name, opset_version)
+    operator_attributes: list[str] = list(operator_schema.attributes)
+
+    # Create a base ONNX model to determine the operator type for naming the results file
+    base_model: onnx.ModelProto = create_onnx_model(
+        operator_name,
+        opset_version,
+        base_input_shapes,
+        initializer_rank,
+        **base_attributes,
+    )
+    operator_aidge: str = list(aidge_onnx.load(model=base_model).get_ordered_outputs())[0][0].type()
+
+    # Initialize or load existing benchmark results
+    results = {"library": "", "compare": {}, "time": {}}
+    filename: str = f"{operator_aidge.lower()}_{module_name}.json"
+    results_file_path = os.path.join(save_directory, filename)
+    # results_file_path = get_results_file_path(module_name, operator_aidge, save_directory)
+    if os.path.exists(results_file_path):
+        with open(results_file_path, "r") as f:
+            results = json.load(f)
+    results["library"] = module_name
+
+    # Loop over each test parameter and its values
+    print("Starting tests...")
+    for param, test_values in main_parameters.items():
+        if time_mode:
+            results["time"].setdefault(param, {})
+        if compare_mode:
+            results["compare"].setdefault(param, {})
+
+        for value in test_values:
+            print(f"â–· {param} -- {value}")
+            updated_attrs, updated_input_shapes = update_test_config(
+                param,
+                value,
+                base_attributes,
+                base_input_shapes,
+                other_parameters,
+                operator_attributes,
+            )
+            if updated_attrs is None or updated_input_shapes is None:
+                continue  # Skip this test case if configuration is missing
+
+            # Create the updated ONNX model
+            model = create_onnx_model(
+                operator_name,
+                opset_version,
+                updated_input_shapes,
+                initializer_rank,
+                **updated_attrs,
+            )
+
+            input_data = prepare_input_data(updated_input_shapes, initializer_rank)
+
+            if time_mode:
+                print(f" {'├' if compare_mode else '└'}┬─Measuring kernel inference time...")
+                timing = measure_inference_time(module_name, model, input_data, inference_module)
+                results["time"][param][str(value)] = timing
+                print(
+                    f" {'│' if compare_mode else ' '}└─[ time = {np.array(timing).mean():.2e} ± {np.array(timing).std():.2e} seconds ]"
+                )
+
+            if compare_mode:
+                print(f" └┬─Assessing results are the same as 'onnxruntime' module...")
+                ref = compute_output("onnxruntime", model, input_data, None)
+                tested = compute_output(module_name, model, input_data, inference_module)
+
+                if len(ref) > 1:
+                    print("Multi-output comparison not handled yet")
+                    sys.exit(1)
+                results["compare"][param][str(value)] = bool(
+                    np.all(np.isclose(ref, tested))
+                )
+                print(
+                    f"  └─[ {'o' if results['compare'][param][str(value)] else 'x'} ]"
+                )
+            print()
+
+    # Save results
+    print(f"Printing results to JSON '{results_file_path}'")
+    with open(results_file_path, "w") as outfile:
+        json.dump(results, outfile, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/benchmark_onnxruntime.py b/benchmark/benchmark_onnxruntime.py
new file mode 100644
index 000000000..cb7336bcf
--- /dev/null
+++ b/benchmark/benchmark_onnxruntime.py
@@ -0,0 +1,39 @@
+import numpy as np
+import onnxruntime as ort
+from onnx import ModelProto
+import time
+
+def measure_inference_time(model: ModelProto, input_data: dict[str, np.ndarray], nb_warmup: int = 10, nb_iterations: int = 50) -> list[float]:
+    """
+    Run the provided ONNX model using ONNXRuntime.
+    Performs `nb_warmup` warm-up runs followed by `nb_iterations` timed runs (using CPU process time).
+
+    Args:
+        model: The ONNX model (ModelProto).
+        input_data: Dictionary mapping all input names to NumPy arrays.
+        nb_warmup: Number of untimed warm-up runs.
+        nb_iterations: Number of timed runs.
+
+    Returns:
+        List of CPU times (in seconds), one per timed run.
+    """
+    sess_opt = ort.SessionOptions()
+    sess_opt.intra_op_num_threads = 1
+    sess = ort.InferenceSession(model.SerializeToString(), sess_opt)
+
+    timings = []
+    # Warm-up runs followed by timed runs.
+    for i in range(nb_warmup + nb_iterations):
+        if i < nb_warmup:
+            sess.run(None, input_data)
+        else:
+            start = time.process_time()
+            sess.run(None, input_data)
+            end = time.process_time()
+            timings.append((end - start))
+    return timings
+
+def compute_output(model: ModelProto, input_data: dict[str, np.ndarray]) -> list[np.ndarray]:
+    sess = ort.InferenceSession(model.SerializeToString())
+    # Run the session with the provided input data and return all outputs.
+    outputs = sess.run(None, input_data)
+    return outputs
\ No newline at end of file
diff --git a/benchmark/benchmark_torch.py b/benchmark/benchmark_torch.py
new file mode 100644
index 000000000..0275daa01
--- /dev/null
+++ b/benchmark/benchmark_torch.py
@@ -0,0 +1,65 @@
+import torch
+import numpy as np
+from onnx import ModelProto
+from onnx2torch import convert
+import time
+
+def measure_inference_time(model_onnx: ModelProto, input_data: dict[str, np.ndarray], nb_warmup: int = 10, nb_iterations: int = 50) -> list[float]:
+    """
+    Run the provided ONNX model after conversion to a PyTorch module.
+    Performs `nb_warmup` warm-up runs followed by `nb_iterations` timed runs (using CPU process time).
+
+    Args:
+        model_onnx: The ONNX model (ModelProto).
+        input_data: Dictionary mapping all input names to NumPy arrays.
+        nb_warmup: Number of untimed warm-up runs.
+        nb_iterations: Number of timed runs.
+
+    Returns:
+        List of CPU times (in seconds), one per timed run.
+    """
+    model = convert(model_onnx)
+
+    device = torch.device("cpu")
+    model.to(device)
+    model.eval()
+
+    torch.set_num_threads(1)
+
+    inputs = [torch.tensor(v, device=device) for _, v in input_data.items()]
+    timings = []
+
+    with torch.no_grad():
+        # Warm-up runs followed by timed runs
+        for i in range(nb_warmup + nb_iterations):
+            if i < nb_warmup:
+                model(*inputs)
+            else:
+                start = time.process_time()
+                model(*inputs)
+                end = time.process_time()
+                timings.append(end - start)
+    return timings
+
+def compute_output(model_onnx: ModelProto, input_data: dict[str, np.ndarray]) -> list[np.ndarray]:
+    """
+    Run the model inference through PyTorch after converting the ONNX model.
+
+    Args:
+        model_onnx: The ONNX model (ModelProto).
+        input_data: Dictionary mapping all input names to NumPy arrays.
+
+    Returns:
+        The output tensors converted to NumPy arrays.
+    """
+    model = convert(model_onnx)
+
+    device = torch.device("cpu")
+    model.to(device)
+    model.eval()
+
+    inputs = [torch.tensor(v, device=device) for _, v in input_data.items()]
+
+    with torch.no_grad():
+        outputs = model(*inputs)
+    outputs = [o.numpy() if isinstance(o, torch.Tensor) else np.array(o) for o in outputs]
+
+    return outputs
diff --git a/benchmark/generate_graph.py b/benchmark/generate_graph.py
new file mode 100644
index 000000000..720755225
--- /dev/null
+++ b/benchmark/generate_graph.py
@@ -0,0 +1,256 @@
+import argparse
+import json
+import sys
+import textwrap
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+# Set a default style
+sns.set_theme(style="ticks", palette="flare")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Compare time performance of operator kernels by plotting relative differences."
+    )
+    parser.add_argument(
+        "--operator-config", "-oc", type=str, default="config.json",
+        help="Path to the configuration JSON file (default: config.json)"
+    )
+    parser.add_argument(
+        "--ref", "-r", type=str, required=True,
+        help="Path to the JSON file with reference results"
+    )
+    parser.add_argument(
+        "--libs", "-l", type=str, nargs='+', required=True,
+        help=("Paths to one or more JSON files with library results. For violin/box mode, "
+              "exactly one file must be provided. For bar mode, multiple files can be provided.")
+    )
+    # parser.add_argument(
+    #     "--plot-type", "-pt", type=str, choices=["violin", "box", "bar"], default="violin",
+    #     help="Type of plot to use: 'violin', 'box', or 'bar' (default: violin)"
+    # )
+    return parser.parse_args()
+
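+# Example invocation (file names are illustrative):
+#   python generate_graph.py -oc conv2d_config.json -r conv_onnxruntime.json \
+#       -l conv_export_lib.json conv_torch.json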
+
+def load_json(file_path: str):
+    with open(file_path, 'r') as f:
+        return json.load(f)
+
+
+def create_relative_difference_plots(test_parameters: dict, ref_times: dict, ref_library: str,
+                                       plot_type: str, new_times: dict = None, new_library: str = None,
+                                       libraries: list = None):
+    """
+    Creates subplots comparing relative differences.
+
+    For "violin" and "box" modes, `new_times` and `new_library` are used.
+    For "bar" mode, a list of library tuples (library_name, times) in `libraries` is used.
+    In bar mode the reference library (ref_library) is always added as the baseline (ratio = 1).
+    """
+    n_params = len(test_parameters)
+    n_cols = 1
+    if n_params > 1:
+        if n_params == 2 or n_params == 4:
+            n_cols = 2
+        else:
+            n_cols = 3
+
+    n_rows = (n_params + n_cols - 1) // n_cols
+
+    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10 if n_params == 1 else 15, 5 * n_rows))
+    # Ensure axes is always iterable
+    if n_params == 1:
+        axes = [axes]
+    axes_flat = axes.flatten() if n_params > 1 else axes
+
+    for idx, (param_name, param_values) in enumerate(test_parameters.items()):
+        ax = axes_flat[idx]
+        # if plot_type in ["violin", "box"]:
+        #     # Compute relative differences (%) per run
+        #     plot_data = []
+        #     for val in param_values:
+        #         new_arr = np.array(new_times[param_name][str(val)])
+        #         ref_arr = np.array(ref_times[param_name][str(val)])
+        #         rel_diff = (new_arr - ref_arr) / ref_arr * 100
+        #         plot_data.extend([(str(val), diff) for diff in rel_diff])
+        #     df = pd.DataFrame(plot_data, columns=[f'{param_name} Value', 'Relative Difference (%)'])
+        #     # Optionally filter extreme outliers using IQR
+        #     for col in df.select_dtypes(include='number').columns:
+        #         Q1 = df[col].quantile(0.25)
+        #         Q3 = df[col].quantile(0.75)
+        #         IQR = Q3 - Q1
+        #         lower_bound = Q1 - IQR
+        #         upper_bound = Q3 + IQR
+        #         df = df[(df[col] >= lower_bound) & (df[col] <= upper_bound)]
+        #     if plot_type == "violin":
+        #         sns.violinplot(
+        #             data=df,
+        #             x=f'{param_name} Value',
+        #             y='Relative Difference (%)',
+        #             hue=f'{param_name} Value',
+        #             palette='flare',
+        #             inner='quartile',
+        #             ax=ax
+        #         )
+        #     else:  # box plot
+        #         sns.boxplot(
+        #             data=df,
+        #             x=f'{param_name} Value',
+        #             y='Relative Difference (%)',
+        #             hue=f'{param_name} Value',
+        #             palette='flare',
+        #             ax=ax
+        #         )
+        #     if ax.get_legend() is not None:
+        #         ax.legend_.remove()
+        #     ax.grid(True, axis='y', alpha=0.5, color='gray')
+        #     ax.axhline(y=0, color='red', linewidth=1)
+        #     ax.set_ylim(-30, 150)
+        #     stats_text = (f'Mean: {df["Relative Difference (%)"].mean():.2f}%\n'
+        #                   f'Std: {df["Relative Difference (%)"].std():.2f}%')
+        #     ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
+        #             verticalalignment='top',
+        #             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
+        # elif plot_type == "bar":
+            # For bar plots: compute the median time ratio for each library compared to reference
+        data = []
+        for val in param_values:
+            data.append((str(val), ref_library, 1.0))  # Reference baseline (ratio = 1)
+            for lib_name, lib_times in libraries:
+                lib_arr = np.array(lib_times[param_name][str(val)])
+                ref_arr = np.array(ref_times[param_name][str(val)])
+                median_lib = np.median(lib_arr)
+                median_ref = np.median(ref_arr)
+                ratio = median_lib / median_ref if median_ref != 0 else np.nan
+                data.append((str(val), lib_name, ratio))
+        df_bar = pd.DataFrame(data, columns=[f'{param_name}', 'Library', 'Ratio'])
+        sns.barplot(
+            data=df_bar,
+            x=f'{param_name}',
+            y='Ratio',
+            hue='Library',
+            palette='viridis',
+            ax=ax,
+            errorbar=None
+        )
+        ax.grid(True, axis='y', alpha=0.5, color='gray')
+        # Annotate bars with their ratio values
+        for container in ax.containers:
+            labels = [f'{h:.2f}' if h > 1e-6 else '' for h in container.datavalues]
+            ax.bar_label(container, labels=labels, padding=3)
+        ax.set_ylim(0, max(df_bar['Ratio'].max() * 1.1, 1.1))
+        # else:
+        #     ax.text(0.5, 0.5, "Unknown plot type", horizontalalignment='center', verticalalignment='center')
+
+    # Remove any unused subplots
+    for idx in range(len(test_parameters), len(axes_flat)):
+        fig.delaxes(axes_flat[idx])
+    if n_params == 1:
+        plt.tight_layout(rect=[0, 0.05, 1, 0.88])
+    else:
+        plt.tight_layout(rect=[0, 0.05, 1, 0.93])
+
+    # Create a common legend (if any) at the top center
+    common_handles, common_labels = None, None
+    for ax in fig.axes:
+        leg = ax.get_legend()
+        if leg is not None:
+            common_handles, common_labels = ax.get_legend_handles_labels()
+            break
+    if common_handles is not None and common_labels is not None:
+        fig.legend(common_handles, common_labels, loc='upper center', ncol=len(common_labels),
+                   bbox_to_anchor=(0.5, 0.99), title="Library", fontsize=14)
+    # Remove legends from individual subplots
+    for ax in axes_flat:
+        if ax.get_legend() is not None:
+            ax.get_legend().remove()
+    return fig
+
+
+def main():
+    args = parse_args()
+    config = load_json(args.operator_config)
+    ref_results = load_json(args.ref)
+    library_files = args.libs
+
+    operator = config["operator"]
+    test_parameters = config["test_configuration"].get("main_parameters", {})
+
+    # Load reference times and library name from reference JSON
+    ref_times = ref_results.get("time")
+    ref_library = ref_results.get("library", "ref_lib")
+    if ref_times is None:
+        print("Reference JSON does not contain time results.")
+        sys.exit(1)
+
+    # if args.plot_type in ["violin", "box"]:
+    #     if len(library_files) != 1:
+    #         print("Error: For violin/box mode, exactly one library JSON file must be provided.")
+    #         sys.exit(1)
+    #     comp_results = load_json(library_files[0])
+    #     comp_times = comp_results.get("time")
+    #     comp_library = comp_results.get("library", "comp_lib")
+    #     if comp_times is None:
+    #         print("Library JSON does not contain time results.")
+    #         sys.exit(1)
+    #     fig = create_relative_difference_plots(
+    #         test_parameters, ref_times, ref_library,
+    #         plot_type=args.plot_type, new_times=comp_times, new_library=comp_library
+    #     )
+    #     filename = f"{operator}_comp_{comp_library}-vs-{ref_library}.svg"
+    # elif args.plot_type == "bar":
+    libraries = []
+    for lib_file in library_files:
+        lib_result = load_json(lib_file)
+        lib_times = lib_result.get("time")
+        lib_name = lib_result.get("library", "lib")
+        if lib_times is None:
+            print(f"Library JSON {lib_file} does not contain time results. Skipping.")
+            continue
+        libraries.append((lib_name, lib_times))
+    if not libraries:
+        print("No valid library results available for bar plot.")
+        sys.exit(1)
+    fig = create_relative_difference_plots(
+        test_parameters, ref_times, ref_library,
+        plot_type="bar", libraries=libraries
+    )
+    # lib_names = "-".join([name for name, _ in libraries])
+    filename = f"{operator}_inference_time_comparison.svg"
+    # else:
+    #     print("Unsupported plot type.")
+    #     sys.exit(1)
+
+    ##############################
+    # Prepare footer texts
+    footer_title = f'[{operator}] kernel relative inference time comparison'
+    default_config = config.get("base_configuration", {})
+
+    # Wrap the default configuration text to a given width.
+    wrapped_config = textwrap.wrap(f'Base configuration: {default_config}', width=160)
+    n_lines = len(wrapped_config)
+    config_text = "\n".join(wrapped_config)
+
+    # Adjust the figure layout to provide extra space at the bottom.
+    if len(test_parameters) == 1:
+        plt.subplots_adjust(bottom=0.2+0.02*n_lines)
+    else:
+        plt.subplots_adjust(bottom=0.14+0.02*n_lines)
+
+    # Add the footer title (bottom center) with fontsize 18.
+    fig.text(0.5, 0.035+n_lines*0.025, footer_title, ha='center', va='bottom', fontsize=18)
+
+    # Add the default configuration text just below the title (fontsize 12).
+    fig.text(0.5, 0.02, config_text, ha='center', va='bottom', fontsize=12)
+
+    ############################
+    # save
+    plt.savefig(filename)
+    print(f"Plot saved as {filename}")
+
+
+if __name__ == "__main__":
+    main()
-- 
GitLab


From 2432868b646efab56819a117fbed94244c469456 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 3 Apr 2025 09:39:44 +0000
Subject: [PATCH 2/5] add: main export scripts for timing measurements and
 output display

---
 aidge_core/export_utils/__init__.py           |   2 +-
 aidge_core/export_utils/generate_main.py      | 124 ++++++++++++++++++
 .../main_benchmark_display_output.jinja       |  45 +++++++
 .../main_benchmark_inference_time.jinja       |  51 +++++++
 4 files changed, 221 insertions(+), 1 deletion(-)
 create mode 100644 aidge_core/export_utils/templates/main_benchmark_display_output.jinja
 create mode 100644 aidge_core/export_utils/templates/main_benchmark_inference_time.jinja

diff --git a/aidge_core/export_utils/__init__.py b/aidge_core/export_utils/__init__.py
index 72472eee0..87a43f6a6 100644
--- a/aidge_core/export_utils/__init__.py
+++ b/aidge_core/export_utils/__init__.py
@@ -3,4 +3,4 @@ from .code_generation import generate_file, generate_str, copy_file
 from .export_registry import ExportLib
 from .scheduler_export import scheduler_export
 from .tensor_export import tensor_to_c, generate_input_file
-from .generate_main import generate_main_cpp, generate_main_compare_cpp
+from .generate_main import generate_main_cpp, generate_main_compare_cpp, generate_main_inference_time_cpp, generate_main_display_output_cpp
diff --git a/aidge_core/export_utils/generate_main.py b/aidge_core/export_utils/generate_main.py
index 57fc68bca..288f4926d 100644
--- a/aidge_core/export_utils/generate_main.py
+++ b/aidge_core/export_utils/generate_main.py
@@ -129,3 +129,127 @@ def generate_main_compare_cpp(export_folder: str, graph_view: aidge_core.GraphVi
         outputs_dtype=outputs_dtype,
         outputs_size=outputs_size
     )
+
+def generate_main_inference_time_cpp(export_folder: str, graph_view: aidge_core.GraphView, nb_iterations: int, nb_warmup: int, inputs_tensor=None) -> None:
+    """
+    Generate a C++ file to manage the forward pass of a model using the given graph structure.
+
+    This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types,
+    and tensor sizes. It uses this data to populate a C++ template (`main_benchmark_inference_time.jinja`), creating a file (`main.cpp`)
+    that calls the `model_forward` function, which handles data flow and processing for the exported model.
+
+    This function also generates files containing the input tensors if they have been set.
+
+    :param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved.
+    :type export_folder: str
+    :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and
+                       ordered input/output data within the computational graph.
+    :type graph_view: aidge_core.graph_view
+    :param nb_iterations: Number of timed forward calls in the generated ``main.cpp``.
+    :param nb_warmup: Number of untimed warm-up calls run before timing starts.
+    :param inputs_tensor: **Reserved for future use**: tensors to use in the main function; not implemented yet.
+    :type inputs_tensor: None
+    :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes),
+                          indicating an internal bug in the graph representation.
+    """
+    outputs_name: list[str] = []
+    outputs_dtype: list[str] = []
+    outputs_size: list[int] = []
+    inputs_name: list[str] = []
+    gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs()
+    gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs()
+
+    for in_node, in_idx in gv_inputs:
+        in_node_input, in_node_input_idx = in_node.input(in_idx)
+        in_name = f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}"
+        inputs_name.append(in_name)
+        input_tensor = in_node.get_operator().get_input(in_idx)
+        if input_tensor is None or input_tensor.undefined() or not input_tensor.has_impl():
+            if inputs_tensor is not None:
+                aidge_core.Log.notice("No support for inputs_tensor argument yet.")
+                aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.")
+            else:
+                aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.")
+        else:
+            aidge_core.export_utils.generate_input_file(export_folder=export_folder, array_name=in_name, tensor=input_tensor)
+
+    for out_node, out_id in gv_outputs:
+        outputs_name.append(f"{out_node.name()}_output_{out_id}")
+        out_tensor = out_node.get_operator().get_output(out_id)
+        outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype()))
+        outputs_size.append(out_tensor.size())
+
+    if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
+        raise RuntimeError("FATAL: Output argument lists do not have the same length; this is an internal bug.")
+
+    ROOT = Path(__file__).resolve().parents[0]
+    generate_file(
+        str(Path(export_folder) / "main.cpp"),
+        str(ROOT / "templates" / "main_benchmark_inference_time.jinja"),
+        func_name="model_forward",
+        inputs_name=inputs_name,
+        outputs_name=outputs_name,
+        outputs_dtype=outputs_dtype,
+        outputs_size=outputs_size,
+        nb_iterations=nb_iterations,
+        nb_warmup=nb_warmup
+    )
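+
+# A minimal usage sketch (assumes an already scheduled/exported graph and an
+# existing export folder; the folder name is illustrative):
+#   generate_main_inference_time_cpp("export_op", graph_view, nb_iterations=50, nb_warmup=10)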
+
+def generate_main_display_output_cpp(export_folder: str, graph_view: aidge_core.GraphView, inputs_tensor=None) -> None:
+    """
+    Generate a C++ file to manage the forward pass of a model using the given graph structure.
+
+    This function extracts details from the :py:class:`aidge_core.graph_view` object, including input and output node names, data types,
+    and tensor sizes. It uses this data to populate a C++ template (`main_benchmark_display_output.jinja`), creating a file (`main.cpp`)
+    that calls the `model_forward` function, which handles data flow and processing for the exported model.
+
+    This function also generates files containing the input tensors if they have been set.
+
+    :param export_folder: Path to the folder where the generated C++ file (`main.cpp`) will be saved.
+    :type export_folder: str
+    :param graph_view: An instance of :py:class:`aidge_core.graph_view`, providing access to nodes and
+                       ordered input/output data within the computational graph.
+    :type graph_view: aidge_core.graph_view
+    :param inputs_tensor: **Reserved for future use**: tensors to use in the main function; not implemented yet.
+    :type inputs_tensor: None
+    :raises RuntimeError: If there is an inconsistency in the output arguments (names, data types, sizes),
+                          indicating an internal bug in the graph representation.
+    """
+    outputs_name: list[str] = []
+    outputs_dtype: list[str] = []
+    outputs_size: list[int] = []
+    inputs_name: list[str] = []
+    gv_inputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_inputs()
+    gv_outputs: list[tuple[aidge_core.Node, int]] = graph_view.get_ordered_outputs()
+
+    for in_node, in_idx in gv_inputs:
+        in_node_input, in_node_input_idx = in_node.input(in_idx)
+        in_name = f"{in_node.name()}_input_{in_idx}" if in_node_input is None else f"{in_node_input.name()}_output_{in_node_input_idx}"
+        inputs_name.append(in_name)
+        input_tensor = in_node.get_operator().get_input(in_idx)
+        if input_tensor is None or input_tensor.undefined() or not input_tensor.has_impl():
+            if inputs_tensor is not None:
+                aidge_core.Log.notice("No support for inputs_tensor argument yet.")
+                aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.")
+            else:
+                aidge_core.Log.notice(f"No input tensor set for {in_name}, main generated will not be functionnal after code generation.")
+        else:
+            aidge_core.export_utils.generate_input_file(export_folder=export_folder, array_name=in_name, tensor=input_tensor)
+
+    for out_node, out_id in gv_outputs:
+        outputs_name.append(f"{out_node.name()}_output_{out_id}")
+        out_tensor = out_node.get_operator().get_output(out_id)
+        outputs_dtype.append(data_conversion.aidge2c(out_tensor.dtype()))
+        outputs_size.append(out_tensor.size())
+
+    if len(outputs_name) != len(outputs_dtype) or len(outputs_name) != len(outputs_size):
+        raise RuntimeError("FATAL: Output argument lists do not have the same length; this is an internal bug.")
+
+    ROOT = Path(__file__).resolve().parents[0]
+    generate_file(
+        str(Path(export_folder) / "main.cpp"),
+        str(ROOT / "templates" / "main_benchmark_display_output.jinja"),
+        func_name="model_forward",
+        inputs_name=inputs_name,
+        outputs_name=outputs_name,
+        outputs_dtype=outputs_dtype,
+        outputs_size=outputs_size
+    )
\ No newline at end of file
diff --git a/aidge_core/export_utils/templates/main_benchmark_display_output.jinja b/aidge_core/export_utils/templates/main_benchmark_display_output.jinja
new file mode 100644
index 000000000..967a28063
--- /dev/null
+++ b/aidge_core/export_utils/templates/main_benchmark_display_output.jinja
@@ -0,0 +1,45 @@
+#include <cstddef>
+#include <cstdio>
+
+#include "forward.hpp"
+{% for name in inputs_name %}
+#include "{{ name }}.h"
+{% endfor %}
+
+{% set printf_formats = {
+    "double": "lf",
+    "float": "f",
+    "int8_t": "hhd",
+    "int16_t": "hd",
+    "int32_t": "d",
+    "int64_t": "lld",
+    "uint8_t": "hhu",
+    "uint16_t": "hu",
+    "uint32_t": "u",
+    "uint64_t": "llu"
+} %}
+
+int main()
+{
+    // Initialize the output arrays
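+    // (model_forward is expected to set these pointers to its output buffers;
+    // no allocation is done here)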
+    {%- for o in range(outputs_name | length) %}
+    {{ outputs_dtype[o] }}* {{ outputs_name[o] }} = nullptr;
+    // {{ outputs_dtype[o] }}* results_{{ o }} = new {{ outputs_dtype[o] }}[{{ outputs_size[o] }}];
+    {% endfor %}
+
+    // Call the forward function
+    {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+
+    // Print the results of each output
+    {%- for o in range(outputs_name | length) %}
+    for (std::size_t i = 0; i < {{ outputs_size[o] }}; ++i) {
+    {%- if outputs_dtype[o] in ["double", "float"] %}
+        std::printf("%.10{{ printf_formats[outputs_dtype[o]] }} ", {{ outputs_name[o] }}[i]);
+    {%- else %}
+        std::printf("%{{ printf_formats[outputs_dtype[o]] }} ", {{ outputs_name[o] }}[i]);
+    {%- endif %}
+    }
+    std::printf("\n");
+    {% endfor %}
+    return 0;
+}
diff --git a/aidge_core/export_utils/templates/main_benchmark_inference_time.jinja b/aidge_core/export_utils/templates/main_benchmark_inference_time.jinja
new file mode 100644
index 000000000..0943265e1
--- /dev/null
+++ b/aidge_core/export_utils/templates/main_benchmark_inference_time.jinja
@@ -0,0 +1,51 @@
+
+#include <cstddef>
+#include <cstdio>
+#include <ctime>
+
+#include "forward.hpp"
+{% for name in inputs_name %}
+#include "{{ name }}.h"
+{% endfor %}
+
+{% set printf_formats = {
+    "double": "%lf",
+    "float": "%f",
+    "int8_t": "%hhd",
+    "int16_t": "%hd",
+    "int32_t": "%d",
+    "int64_t": "%lld",
+    "uint8_t": "%hhu",
+    "uint16_t": "%hu",
+    "uint32_t": "%u",
+    "uint64_t": "%llu"
+} %}
+
+int main()
+{
+    // Initialize the output arrays
+    {%- for o in range(outputs_name | length) %}
+    {{ outputs_dtype[o] }}* {{ outputs_name[o] }} = nullptr;
+    {% endfor %}
+    clock_t start;
+    clock_t end;
+    double times[{{ nb_iterations }}] = {0};
+    for (std::size_t i = 0; i < {{ nb_iterations }} + {{ nb_warmup }}; ++i) {
+        if (i < {{ nb_warmup }}) {
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+        } else {
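+            // Four consecutive forward calls are timed per sample to amortize
+            // clock()'s coarse resolution; the span is divided by 4 below.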
+            start = clock();
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            {{ func_name }}({{ inputs_name|join(", ") }}{% if inputs_name %}, {% endif %}&{{ outputs_name|join(", &") }});
+            end = clock();
+            times[i - {{ nb_warmup }}] = ((double)(end - start)/CLOCKS_PER_SEC)/4.0;
+        }
+    }
+
+    for (std::size_t i = 0; i < {{ nb_iterations }}; ++i) {
+        std::printf("%.10lf ", times[i]);
+    }
+    std::printf("\n");
+    return 0;
+}
-- 
GitLab


From cd37e0894a11132bf0f957c48b8b4eaa8923010a Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 3 Apr 2025 09:52:24 +0000
Subject: [PATCH 3/5] update main benchmark script arguments, add warnings and
 comments about limitations, and increase result-comparison tolerances

---
 benchmark/benchmark.py       | 53 ++++++++++++++++++++++++++----------
 benchmark/benchmark_torch.py |  7 ++---
 2 files changed, 42 insertions(+), 18 deletions(-)

diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py
index 6726be68a..1061d51f5 100644
--- a/benchmark/benchmark.py
+++ b/benchmark/benchmark.py
@@ -8,10 +8,11 @@ The configuration is provided via a JSON file.
 
 import argparse
 import copy
-import importlib
+from importlib import import_module
 import json
 import os
 import sys
+from typing import Any
 
 import numpy as np
 import onnx
@@ -31,7 +32,7 @@ def load_inference_module(module_name: str):
     Exits if the module is not installed.
     """
     try:
-        module = importlib.import_module(module_name)
+        module = import_module(module_name)
         print(f"'{module_name}' module successfully imported")
         return module
     except ImportError:
@@ -73,8 +74,8 @@ def update_test_config(
     try:
         extra_attrs = other_parameters[param][str(value)]["attributes"]
     except KeyError:
-        print(
-            f"'{param}': '{value}': 'attributes' - Key not found in other_parameters. Config file may be ill-formed."
+        ai.Log.error(
+            f"Test configuration \{'{param}': '{value}'\} no 'attribute' property found. Config file may be ill-formed."
         )
         return None, None
     attributes.update(extra_attrs)
@@ -82,8 +83,8 @@ def update_test_config(
     try:
         extra_input_shapes = other_parameters[param][str(value)]["input_shapes"]
     except KeyError:
-        print(
-            f"'{param}': '{value}': 'input_shapes' - Key not found in other_parameters. Config file may be ill-formed."
+        ai.Log.error(
+            f"Test configuration \{'{param}': '{value}'\} no 'input_shapes' property found. Config file may be ill-formed."
         )
         return None, None
 
@@ -152,7 +153,8 @@ def compute_output(
 
         return benchmark_torch.compute_output(model, {v[0]: v[1] for v in input_data})
     else:
-        model = aidge_onnx.load(model=model) if "aidge" in module_name else model
+        if "aidge" in module_name:
+            model = aidge_onnx.load(model=model)
         return inference_module.benchmark.compute_output(model, input_data)
 
 
@@ -199,23 +201,29 @@ def main():
         "--time", "-t", action="store_true", help="Compute inference time"
     )
     parser.add_argument(
-        "--save-directory",
+        "--results-directory",
         type=str,
         required=True,
         help="Directory to save the results",
     )
+    parser.add_argument(
+        "--results-filename",
+        type=str,
+        required=False,
+        help="Name of the saved result file. If not provided, it will default to the '<operator_name>_<module_to_bench>.json'. If a file with that nae and at tha location already exists, it will be overrided with elements individually replaced only if new ones are computed"
+    )
     args = parser.parse_args()
 
     compare_mode = args.compare_with_onnxruntime
     time_mode = args.time
     module_name = args.module_to_bench
-    save_directory = args.save_directory
+    results_directory = args.results_directory
 
     # Load the inference module
     inference_module = load_inference_module(module_name)
 
     # Configure aidge logging
-    ai.Log.set_console_level(ai.Level.Error)
+    ai.Log.set_console_level(ai.Level.Warn)
     ai.Log.set_precision(10)
 
     # Load configuration
@@ -224,6 +232,14 @@ def main():
     opset_version: int = config["opset_version"]
     initializer_rank: int = config.get("initializer_rank", 1)
 
+    test_meta_data: dict[str, Any] = config.get("test_meta_data", {})
+    if test_meta_data.get("multiple_batchs", False) and "export" in module_name:
+        ai.Log.warn("The tested module seems to be an export module and your test cases contain "
+            "\033[31;1mmultiple\033[0m batch inputs. This could lead to inaccurate results due to "
+            "the stream-based (single-batch) nature of export implementations, or to an error "
+            "during the export-generation step. Unless you know what you are doing, you should "
+            "probably change your configuration file to single-batch tests.")
+
     base_input_shapes: list[tuple[str, list[int]]] = config["base_configuration"][
         "input_shapes"
     ]
@@ -252,8 +268,8 @@ def main():
 
     # Initialize or load existing benchmark results
     results = {"library": "", "compare": {}, "time": {}}
-    filename: str = f"{operator_aidge.lower()}_{module_name}.json"
-    results_file_path = os.path.join(save_directory, filename)
+    filename: str = (args.results_filename + ".json") if args.results_filename else f"{operator_aidge.lower()}_{module_name}.json"
+    results_file_path = os.path.join(results_directory, filename)
     # results_file_path = get_results_file_path(module_name, operator_aidge, save_directory)
     if os.path.exists(results_file_path):
         with open(results_file_path, "r") as f:
@@ -270,6 +286,13 @@ def main():
 
         for value in test_values:
             print(f"â–· {param} -- {value}")
+            if str(value) not in other_parameters.get(param, {}):
+                ai.Log.error(
+                    f"Test configuration {{'{param}': '{value}'}} not found. Config file may be ill-formed."
+                )
+                continue
             updated_attrs, updated_input_shapes = update_test_config(
                 param,
                 value,
@@ -304,12 +327,14 @@ def main():
                 print(f" └┬─Assessing results are the same as 'onnxruntime' module...")
                 ref = compute_output("onnxruntime", model, input_data, None)
                 tested = compute_output(module_name, model, input_data, inference_module)
-
+                ai.Log.info(f"ref: {ref}\n")
+                ai.Log.info(f"tested: {tested}\n")
                 if len(ref) > 1:
                     print("Multi-output comparison not handled yet")
+                    print([i.shape for i in ref])
                     sys.exit(1)
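+                # Tolerances relaxed vs. np.isclose defaults: float32 kernels may
+                # legitimately differ slightly across backends.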
                 results["compare"][param][str(value)] = bool(
-                    np.all(np.isclose(ref, tested))
+                    np.all(np.isclose(ref, tested, rtol=1e-3, atol=1e-5))
                 )
                 print(
                     f"  └─[ {'o' if results['compare'][param][str(value)] else 'x'} ]"
diff --git a/benchmark/benchmark_torch.py b/benchmark/benchmark_torch.py
index 0275daa01..abe92ffe2 100644
--- a/benchmark/benchmark_torch.py
+++ b/benchmark/benchmark_torch.py
@@ -26,7 +26,6 @@ def measure_inference_time(model_onnx: ModelProto, input_data: dict[str, np.ndar
 
     inputs = [torch.tensor(v, device=device) for _, v in input_data.items()]
     timings = []
-
     with torch.no_grad():
         # Warm-up runs followed by timed runs
         for i in range(nb_warmup + nb_iterations):
@@ -59,7 +58,7 @@ def compute_output(model_onnx: ModelProto, input_data: dict[str, np.ndarray]) ->
     inputs = [torch.tensor(v, device=device) for _, v in input_data.items()]
 
     with torch.no_grad():
-        outputs = model(*inputs)
-    outputs = [o.numpy() if isinstance(o, torch.Tensor) else np.array(o) for o in outputs]
+        # Warning: not tested for the multiple-output case; assumes a single output tensor
+        output = model(*inputs)
 
-    return outputs
+    return output.numpy()
-- 
GitLab


From f537a2f8ec5b48564d3bed4e30f29ef32733d15c Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 3 Apr 2025 12:04:53 +0000
Subject: [PATCH 4/5] add: test config files for some operators (element-wise,
 fc, conv, concat)

---
 benchmark/operator_config/add_config.json    | 160 +++++++++++++++++++
 benchmark/operator_config/concat_config.json |  86 ++++++++++
 benchmark/operator_config/config.json        |  64 ++++++++
 benchmark/operator_config/conv2d_config.json | 112 +++++++++++++
 benchmark/operator_config/div_config.json    | 160 +++++++++++++++++++
 benchmark/operator_config/fc_config.json     |  49 ++++++
 benchmark/operator_config/mul_config.json    | 160 +++++++++++++++++++
 benchmark/operator_config/relu_config.json   |  59 +++++++
 benchmark/operator_config/sub_config.json    | 160 +++++++++++++++++++
 9 files changed, 1010 insertions(+)
 create mode 100644 benchmark/operator_config/add_config.json
 create mode 100644 benchmark/operator_config/concat_config.json
 create mode 100644 benchmark/operator_config/config.json
 create mode 100644 benchmark/operator_config/conv2d_config.json
 create mode 100644 benchmark/operator_config/div_config.json
 create mode 100644 benchmark/operator_config/fc_config.json
 create mode 100644 benchmark/operator_config/mul_config.json
 create mode 100644 benchmark/operator_config/relu_config.json
 create mode 100644 benchmark/operator_config/sub_config.json

diff --git a/benchmark/operator_config/add_config.json b/benchmark/operator_config/add_config.json
new file mode 100644
index 000000000..b0f9424d3
--- /dev/null
+++ b/benchmark/operator_config/add_config.json
@@ -0,0 +1,160 @@
+{
+	"operator": "Add",
+	"opset_version": 21,
+	"initializer_rank": 2,
+    "test_meta_data": {
+        "multiple_batchs": true
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [64, 64, 64, 64]],
+			["input_1", [64, 64, 64, 64]]
+		],
+		"attributes": {}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"dim size": [
+				1,4,16,64,128
+			],
+            "one dim broadcasted (idx)": [
+				0, 1, 2, 3
+			],
+            "two dims broadcasted (idx)": [
+				[0,1], [0,2], [0,3], [1,2], [1,3], [2,3]
+			],
+            "nb missing axis 1st input": [
+                1, 2, 3, 4
+            ]
+		},
+		"other_parameters": {
+			"dim size": {
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 1, 1]],
+                        ["input_1", [1, 1, 1, 1]]
+					]
+				},
+				"4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [4, 4, 4, 4]],
+                        ["input_1", [4, 4, 4, 4]]
+					]
+				},
+				"16": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [16, 16, 16, 16]],
+                        ["input_1", [16, 16, 16, 16]]
+					]
+				},
+                "64": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 64]],
+                        ["input_1", [64, 64, 64, 64]]
+					]
+				},
+                "128": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [128, 128, 128, 128]],
+                        ["input_1", [128, 128, 128, 128]]
+					]
+				}
+			},
+            "one dim broadcasted (idx)": {
+                "0": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 64]]
+					]
+				},
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 64]]
+					]
+				},
+                "2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 1]]
+					]
+				}
+            },
+            "two dims broadcasted (idx)": {
+                "[0, 1]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 64, 64]]
+					]
+				},
+                "[0, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 1, 64]]
+					]
+				},
+                "[0, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 1]]
+					]
+				},
+                "[1, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 1, 64]]
+					]
+				},
+                "[1, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 1]]
+					]
+				},
+                "[2, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 1]]
+					]
+				}
+            },
+            "nb missing axis 1st input": {
+				"1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64]]
+					]
+				},
+				"2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64]]
+					]
+				},
+                "4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", []]
+					]
+				}
+			}
+		}
+	}
+}
\ No newline at end of file
diff --git a/benchmark/operator_config/concat_config.json b/benchmark/operator_config/concat_config.json
new file mode 100644
index 000000000..344d77a7c
--- /dev/null
+++ b/benchmark/operator_config/concat_config.json
@@ -0,0 +1,86 @@
+{
+    "operator": "Concat",
+    "opset_version": 13,
+    "initializer_rank": 3,
+    "test_meta_data": {
+        "multiple_batchs": false
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [1, 64, 64, 64]],
+			["input_1", [1, 64, 64, 64]],
+			["input_2", [1, 64, 64, 64]]
+		],
+		"attributes": {
+			"axis": 1
+		}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"axis": [
+				0,1,2,3
+			],
+			"dims size": [
+				1, 4, 16, 64, 128
+			]
+		},
+		"other_parameters": {
+			"axis": {
+				"0": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+				"1": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+				"2": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+				"3": {
+                    "attributes": {},
+                    "input_shapes": []
+                }
+			},
+			"dims size": {
+				"1": {
+					"attributes": {},
+					"input_shapes": [
+                        ["input_0", [1, 1, 1, 1]],
+                        ["input_1", [1, 1, 1, 1]],
+                        ["input_2", [1, 1, 1, 1]]
+                    ]
+				},
+				"4": {
+					"attributes": {},
+					"input_shapes": [
+                        ["input_0", [1, 4, 4, 4]],
+                        ["input_1", [1, 4, 4, 4]],
+                        ["input_2", [1, 4, 4, 4]]
+                    ]
+				},
+				"16": {
+					"attributes": {},
+					"input_shapes": [
+                        ["input_0", [1, 16, 16, 16]],
+                        ["input_1", [1, 16, 16, 16]],
+                        ["input_2", [1, 16, 16, 16]]
+                    ]
+				},
+                "64": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+                "256": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_0", [1, 128, 128, 128]],
+                        ["input_1", [1, 128, 128, 128]],
+                        ["input_2", [1, 128, 128, 128]]
+                    ]
+                }
+			}
+		}
+	}
+}
\ No newline at end of file
diff --git a/benchmark/operator_config/config.json b/benchmark/operator_config/config.json
new file mode 100644
index 000000000..cd471787e
--- /dev/null
+++ b/benchmark/operator_config/config.json
@@ -0,0 +1,64 @@
+{
+	"operator": "onnx_operator_type",
+	"opset_version": 21,
+	"initializer_rank": 2,
+    "test_meta_data": {
+        "multiple_batchs": false
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [1, 10, 10, 10]],
+			["input_1", [1, 10, 10, 10]]
+		],
+		"attributes": {
+			"attributes_name_0": 0,
+			"attributes_name_1": 0
+		}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"tested_param_0": [
+				1, 2, 3
+			],
+			"tested_param_1": [
+				[1, 1],
+				[3, 3],
+                [5, 5]
+			]
+		},
+		"other_parameters": {
+			"tested_param_0": {
+				"1": {
+					"attributes": {
+                        "attributes_name_0": 1
+                    },
+					"input_shapes": [
+						["input_0", [1, 10, 10, 10]]
+					]
+				},
+				"2": {},
+				"3": {}
+			},
+			"tested_param_1": {
+				"[1, 1]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_1", [1, 10, 1, 1]]
+					]
+				},
+				"[3, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_1", [1, 10, 3, 3]]
+					]
+				},
+				"[5, 5]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_1", [1, 10, 5, 5]]
+					]
+				}
+			}
+		}
+	}
+}
\ No newline at end of file
diff --git a/benchmark/operator_config/conv2d_config.json b/benchmark/operator_config/conv2d_config.json
new file mode 100644
index 000000000..bf79dc90e
--- /dev/null
+++ b/benchmark/operator_config/conv2d_config.json
@@ -0,0 +1,112 @@
+{
+	"operator": "Conv",
+	"opset_version": 21,
+	"initializer_rank": 1,
+    "test_meta_data": {
+        "multiple_batchs": false
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [1, 10, 200, 200]],
+			["weight_1", [10, 10, 3, 3]],
+			["bias_2", [10]]
+		],
+		"attributes": {
+			"kernel_shape": [3, 3],
+			"strides": [1, 1],
+			"dilations": [1, 1]
+		}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"feature_map_size": [
+				10, 100, 500
+			],
+			"kernel_shape": [
+				[1, 1],
+				[3, 3],
+                [5, 5]
+			],
+			"strides": [
+				[1, 1],
+				[2, 2],
+				[3, 3]
+			],
+			"dilations": [
+				[1, 1],
+				[2, 2],
+				[3, 3]
+			]
+		},
+		"other_parameters": {
+			"feature_map_size": {
+				"10": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 10, 10, 10]]
+					]
+				},
+				"100": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 10, 100, 100]]
+					]
+				},
+				"500": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 10, 500, 500]]
+					]
+				}
+			},
+			"kernel_shape": {
+				"[1, 1]": {
+					"attributes": {},
+					"input_shapes": [
+						["weight_1", [10, 10, 1, 1]]
+					]
+				},
+				"[3, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["weight_1", [10, 10, 3, 3]]
+					]
+				},
+				"[5, 5]": {
+					"attributes": {},
+					"input_shapes": [
+						["weight_1", [10, 10, 5, 5]]
+					]
+				}
+			},
+			"strides": {
+				"[1, 1]": {
+					"attributes": {},
+					"input_shapes": []
+				},
+				"[2, 2]": {
+					"attributes": {},
+					"input_shapes": []
+				},
+				"[3, 3]": {
+					"attributes": {},
+					"input_shapes": []
+				}
+			},
+			"dilations": {
+				"[1, 1]": {
+					"attributes": {},
+					"input_shapes": []
+				},
+				"[2, 2]": {
+					"attributes": {},
+					"input_shapes": []
+				},
+				"[3, 3]": {
+					"attributes": {},
+					"input_shapes": []
+				}
+			}
+		}
+	}
+}
\ No newline at end of file
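
For reference, the sweeps above change the Conv output size according to the
standard ONNX formula; since this config sets no 'pads' attribute, padding
defaults to zero. A small sketch of the expected output sizes (illustrative,
not part of the benchmark scripts):

    def conv_out_size(in_size: int, kernel: int, stride: int = 1,
                      dilation: int = 1, pad: int = 0) -> int:
        # Standard ONNX/PyTorch convolution output-size formula.
        effective_kernel = dilation * (kernel - 1) + 1
        return (in_size + 2 * pad - effective_kernel) // stride + 1

    # Base configuration: 200x200 input, 3x3 kernel, stride 1, dilation 1
    print(conv_out_size(200, 3))             # 198
    # 'strides' sweep at [3, 3]
    print(conv_out_size(200, 3, stride=3))   # 66
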
diff --git a/benchmark/operator_config/div_config.json b/benchmark/operator_config/div_config.json
new file mode 100644
index 000000000..446715ad6
--- /dev/null
+++ b/benchmark/operator_config/div_config.json
@@ -0,0 +1,160 @@
+{
+	"operator": "Div",
+	"opset_version": 21,
+	"initializer_rank": 2,
+    "test_meta_data": {
+        "multiple_batchs": true
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [64, 64, 64, 64]],
+			["input_1", [64, 64, 64, 64]]
+		],
+		"attributes": {}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"dim size": [
+				1, 4, 16, 64, 128
+			],
+            "one dim broadcasted (idx)": [
+				0, 1, 2, 3
+			],
+            "two dims broadcasted (idx)": [
+				[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]
+			],
+            "nb missing axis 1st input": [
+                1, 2, 3, 4
+            ]
+		},
+		"other_parameters": {
+			"dim size": {
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 1, 1]],
+                        ["input_1", [1, 1, 1, 1]]
+					]
+				},
+				"4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [4, 4, 4, 4]],
+                        ["input_1", [4, 4, 4, 4]]
+					]
+				},
+				"16": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [16, 16, 16, 16]],
+                        ["input_1", [16, 16, 16, 16]]
+					]
+				},
+                "64": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 64]],
+                        ["input_1", [64, 64, 64, 64]]
+					]
+				},
+                "128": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [128, 128, 128, 128]],
+                        ["input_1", [128, 128, 128, 128]]
+					]
+				}
+			},
+            "one dim broadcasted (idx)": {
+                "0": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 64]]
+					]
+				},
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 64]]
+					]
+				},
+                "2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 1]]
+					]
+				}
+            },
+            "two dims broadcasted (idx)": {
+                "[0, 1]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 64, 64]]
+					]
+				},
+                "[0, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 1, 64]]
+					]
+				},
+                "[0, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 1]]
+					]
+				},
+                "[1, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 1, 64]]
+					]
+				},
+                "[1, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 1]]
+					]
+				},
+                "[2, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 1]]
+					]
+				}
+            },
+            "nb missing axis 1st input": {
+				"1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64]]
+					]
+				},
+				"2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64]]
+					]
+				},
+                "4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", []]
+					]
+				}
+			}
+		}
+	}
+}
\ No newline at end of file
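
The 'one dim broadcasted', 'two dims broadcasted', and 'nb missing axis' cases
above all rely on NumPy/ONNX-style broadcasting against the [64, 64, 64, 64]
second input: axes of size 1 are stretched, and missing leading axes are
implicitly added. A quick check with NumPy:

    import numpy as np

    b = np.ones((64, 64, 64, 64), dtype=np.float32)
    for shape in [(1, 64, 64, 64), (64, 1, 1, 64), (64, 64, 64), ()]:
        a = np.ones(shape, dtype=np.float32)
        # Every case broadcasts to the full (64, 64, 64, 64) output.
        print(shape, "->", (a / b).shape)
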
diff --git a/benchmark/operator_config/fc_config.json b/benchmark/operator_config/fc_config.json
new file mode 100644
index 000000000..6a401c84a
--- /dev/null
+++ b/benchmark/operator_config/fc_config.json
@@ -0,0 +1,49 @@
+{
+	"operator": "Gemm",
+	"opset_version": 21,
+	"initializer_rank": 1,
+    "test_meta_data": {
+        "multiple_batchs": false
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [1, 100]],
+			["weight_1", [100, 50]],
+            ["bias_2", [50]]
+		],
+		"attributes": {
+		}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"input_size_0": [
+				10, 100, 1000
+			]
+		},
+		"other_parameters": {
+			"input_size_0": {
+				"10": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 10]],
+                        ["weight_1", [10, 50]]
+					]
+				},
+				"100": {
+                    "attributes": {},
+					"input_shapes": [
+						["input_0", [1, 100]],
+                        ["weight_1", [100, 50]]
+					]
+                },
+				"1000": {
+                    "attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1000]],
+                        ["weight_1", [1000, 50]]
+					]
+                }
+			}
+		}
+	}
+}
\ No newline at end of file
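
In the Gemm sweep above only the inner (reduction) dimension K changes; with
input [1, K], weight [K, 50], and bias [50], the output stays [1, 50] for every
value of 'input_size_0'. A shape check with NumPy:

    import numpy as np

    for k in (10, 100, 1000):
        x = np.random.rand(1, k).astype(np.float32)
        w = np.random.rand(k, 50).astype(np.float32)
        bias = np.random.rand(50).astype(np.float32)
        y = x @ w + bias   # Gemm: x * w + bias
        print(k, "->", y.shape)   # (1, 50) in every case
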
diff --git a/benchmark/operator_config/mul_config.json b/benchmark/operator_config/mul_config.json
new file mode 100644
index 000000000..3c89d5671
--- /dev/null
+++ b/benchmark/operator_config/mul_config.json
@@ -0,0 +1,160 @@
+{
+	"operator": "Mul",
+	"opset_version": 21,
+	"initializer_rank": 2,
+    "test_meta_data": {
+        "multiple_batchs": true
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [64, 64, 64, 64]],
+			["input_1", [64, 64, 64, 64]]
+		],
+		"attributes": {}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"dim size": [
+				1, 4, 16, 64, 128
+			],
+            "one dim broadcasted (idx)": [
+				0, 1, 2, 3
+			],
+            "two dims broadcasted (idx)": [
+				[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]
+			],
+            "nb missing axis 1st input": [
+                1, 2, 3, 4
+            ]
+		},
+		"other_parameters": {
+			"dim size": {
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 1, 1]],
+                        ["input_1", [1, 1, 1, 1]]
+					]
+				},
+				"4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [4, 4, 4, 4]],
+                        ["input_1", [4, 4, 4, 4]]
+					]
+				},
+				"16": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [16, 16, 16, 16]],
+                        ["input_1", [16, 16, 16, 16]]
+					]
+				},
+                "64": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 64]],
+                        ["input_1", [64, 64, 64, 64]]
+					]
+				},
+                "128": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [128, 128, 128, 128]],
+                        ["input_1", [128, 128, 128, 128]]
+					]
+				}
+			},
+            "one dim broadcasted (idx)": {
+                "0": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 64]]
+					]
+				},
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 64]]
+					]
+				},
+                "2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 1]]
+					]
+				}
+            },
+            "two dims broadcasted (idx)": {
+                "[0, 1]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 64, 64]]
+					]
+				},
+                "[0, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 1, 64]]
+					]
+				},
+                "[0, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 1]]
+					]
+				},
+                "[1, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 1, 64]]
+					]
+				},
+                "[1, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 1]]
+					]
+				},
+                "[2, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 1]]
+					]
+				}
+            },
+            "nb missing axis 1st input": {
+				"1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64]]
+					]
+				},
+				"2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64]]
+					]
+				},
+                "4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", []]
+					]
+				}
+			}
+		}
+	}
+}
\ No newline at end of file
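
The broadcast result shape these Mul cases rely on can also be computed by
hand: right-align the two shapes, pad the shorter one with ones, then take the
per-axis maximum (sizes must match or one of them must be 1). A minimal sketch
(broadcast_shape is an illustrative helper, not part of these scripts):

    def broadcast_shape(a, b):
        # Right-align the two shapes, padding the shorter one with 1s.
        n = max(len(a), len(b))
        a = (1,) * (n - len(a)) + tuple(a)
        b = (1,) * (n - len(b)) + tuple(b)
        out = []
        for x, y in zip(a, b):
            if x != y and 1 not in (x, y):
                raise ValueError(f"incompatible sizes {x} and {y}")
            out.append(max(x, y))
        return tuple(out)

    print(broadcast_shape((64, 1, 64, 1), (64, 64, 64, 64)))  # (64, 64, 64, 64)
    print(broadcast_shape((64,), (64, 64, 64, 64)))           # (64, 64, 64, 64)
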
diff --git a/benchmark/operator_config/relu_config.json b/benchmark/operator_config/relu_config.json
new file mode 100644
index 000000000..f97b85f5f
--- /dev/null
+++ b/benchmark/operator_config/relu_config.json
@@ -0,0 +1,59 @@
+{
+    "operator": "Relu",
+    "opset_version": 14,
+    "initializer_rank": 1,
+    "test_meta_data": {
+        "multiple_batchs": true
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [64, 64, 64, 64]]
+		],
+		"attributes": {}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"dims size": [
+				1, 4, 16, 32, 64, 128
+			]
+		},
+		"other_parameters": {
+			"dims size": {
+				"1": {
+					"attributes": {},
+					"input_shapes": [
+                        ["input_0", [1, 1, 1, 1]]
+                    ]
+				},
+				"4": {
+					"attributes": {},
+					"input_shapes": [
+                        ["input_0", [4, 4, 4, 4]]
+                    ]
+				},
+				"16": {
+					"attributes": {},
+					"input_shapes": [
+                        ["input_0", [16, 16, 16, 16]]
+                    ]
+				},
+                "32": {
+					"attributes": {},
+					"input_shapes": [
+                        ["input_0", [32, 32, 32, 32]]
+                    ]
+				},
+                "64": {
+                    "attributes": {},
+                    "input_shapes": []
+                },
+                "128": {
+                    "attributes": {},
+                    "input_shapes": [
+                        ["input_0", [128, 128, 128, 128]]
+                    ]
+                }
+			}
+		}
+	}
+}
\ No newline at end of file
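
One practical note on the 'dims size' sweep above: at float32, the
[128, 128, 128, 128] case already needs 1 GiB per tensor, so the upper end of
the sweep can be memory-bound on many machines. A rough estimate:

    for n in (1, 4, 16, 32, 64, 128):
        nbytes = (n ** 4) * 4   # four dims of size n, 4 bytes per float32
        print(f"{n:>3}: {nbytes / 2**30:.3f} GiB")
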
diff --git a/benchmark/operator_config/sub_config.json b/benchmark/operator_config/sub_config.json
new file mode 100644
index 000000000..17359793a
--- /dev/null
+++ b/benchmark/operator_config/sub_config.json
@@ -0,0 +1,160 @@
+{
+	"operator": "Sub",
+	"opset_version": 21,
+	"initializer_rank": 2,
+    "test_meta_data": {
+        "multiple_batchs": true
+    },
+	"base_configuration": {
+		"input_shapes": [
+			["input_0", [64, 64, 64, 64]],
+			["input_1", [64, 64, 64, 64]]
+		],
+		"attributes": {}
+	},
+	"test_configuration": {
+		"main_parameters": {
+			"dim size": [
+				1, 4, 16, 64, 128
+			],
+            "one dim broadcasted (idx)": [
+				0, 1, 2, 3
+			],
+            "two dims broadcasted (idx)": [
+				[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]
+			],
+            "nb missing axis 1st input": [
+                1, 2, 3, 4
+            ]
+		},
+		"other_parameters": {
+			"dim size": {
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 1, 1]],
+                        ["input_1", [1, 1, 1, 1]]
+					]
+				},
+				"4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [4, 4, 4, 4]],
+                        ["input_1", [4, 4, 4, 4]]
+					]
+				},
+				"16": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [16, 16, 16, 16]],
+                        ["input_1", [16, 16, 16, 16]]
+					]
+				},
+                "64": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 64]],
+                        ["input_1", [64, 64, 64, 64]]
+					]
+				},
+                "128": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [128, 128, 128, 128]],
+                        ["input_1", [128, 128, 128, 128]]
+					]
+				}
+			},
+            "one dim broadcasted (idx)": {
+                "0": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 64]]
+					]
+				},
+                "1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 64]]
+					]
+				},
+                "2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64, 1]]
+					]
+				}
+            },
+            "two dims broadcasted (idx)": {
+                "[0, 1]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 1, 64, 64]]
+					]
+				},
+                "[0, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 1, 64]]
+					]
+				},
+                "[0, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [1, 64, 64, 1]]
+					]
+				},
+                "[1, 2]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 1, 64]]
+					]
+				},
+                "[1, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 1, 64, 1]]
+					]
+				},
+                "[2, 3]": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 1, 1]]
+					]
+				}
+            },
+            "nb missing axis 1st input": {
+				"1": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64, 64]]
+					]
+				},
+				"2": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64, 64]]
+					]
+				},
+                "3": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", [64]]
+					]
+				},
+                "4": {
+					"attributes": {},
+					"input_shapes": [
+						["input_0", []]
+					]
+				}
+			}
+		}
+	}
+}
\ No newline at end of file
-- 
GitLab


From 3adc3fd7283689d05a36efa289f98236f2a13722 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 3 Apr 2025 12:05:30 +0000
Subject: [PATCH 5/5] change plot color

---
 benchmark/generate_graph.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/generate_graph.py b/benchmark/generate_graph.py
index 720755225..82f61d5c8 100644
--- a/benchmark/generate_graph.py
+++ b/benchmark/generate_graph.py
@@ -132,7 +132,7 @@ def create_relative_difference_plots(test_parameters: dict, ref_times: dict, ref
             x=f'{param_name}',
             y='Ratio',
             hue='Library',
-            palette='viridis',
+            palette='hls',
             ax=ax,
             errorbar=None
         )
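
Rationale for the switch: 'viridis' is a sequential colormap, so neighbouring
bars get similar hues, while 'hls' spaces hues evenly around the color wheel,
which better separates the categorical 'Library' dimension in these bar plots.
A quick comparison:

    import seaborn as sns

    # Three categorical entries: 'hls' hues are maximally separated,
    # 'viridis' hues are adjacent steps of one gradient.
    print(sns.color_palette("viridis", 3))
    print(sns.color_palette("hls", 3))
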
-- 
GitLab