From 844fe303cd6914bd1becae64298d4a12cc128816 Mon Sep 17 00:00:00 2001 From: idealbuq <iryna.dealbuquerquesilva@cea.fr> Date: Wed, 2 Oct 2024 07:48:26 +0000 Subject: [PATCH] Added checks for validity of GraphView and JSON path inputs --- aidge_core/show_graphview.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py index dfc0acc45..ee7908b12 100644 --- a/aidge_core/show_graphview.py +++ b/aidge_core/show_graphview.py @@ -3,6 +3,7 @@ import json import builtins import aidge_core import numpy as np +from pathlib import Path def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]: """ @@ -192,22 +193,31 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable :type params_file_format: str, optional """ + if json_path.is_dir(): + json_path = (json_path.parent).joinpath('model.json') + + elif not json_path.is_dir(): + if json_path.suffix == '.json': + pass + else: + raise Exception('If ``json_path`` contains a filename it must be of JSON format.') if write_trainable_params_ext: - dir_name, fname_ext = os.path.split(json_path) - fname = os.path.splitext(fname_ext)[0] + '_trainable_params' - path_trainable_params = os.path.join(dir_name, fname) - os.mkdir(path_trainable_params) + path_trainable_params = (json_path.parent).joinpath(json_path.stem + '_trainable_params/') else: path_trainable_params = '' - # Sort GraphView in topological order - ordered_nodes = gview.get_ordered_nodes() - - # Create dict from GraphView - graphview_dict = _create_dict(ordered_nodes, write_trainable_params_embed, write_trainable_params_ext, path_trainable_params, params_file_format) + if isinstance(gview, aidge_core.GraphView): + # Sort GraphView in topological order + ordered_nodes = gview.get_ordered_nodes() - # Write dict to JSON - _write_dict_json(graphview_dict, json_path) + # Create dict from GraphView + graphview_dict = _create_dict(ordered_nodes, write_trainable_params_embed, write_trainable_params_ext, path_trainable_params, params_file_format) + + # Write dict to JSON + _write_dict_json(graphview_dict, json_path) + + else: + raise Exception("Graph must be an aidge_core.GraphView instance.") return None \ No newline at end of file -- GitLab