Skip to content
Snippets Groups Projects
Commit 844fe303 authored by Iryna DE ALBUQUERQUE SILVA's avatar Iryna DE ALBUQUERQUE SILVA
Browse files

Added checks for validity of GraphView and JSON path inputs

parent cb4ce83f
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!211Add show_graphview funcionality.
...@@ -3,6 +3,7 @@ import json ...@@ -3,6 +3,7 @@ import json
import builtins import builtins
import aidge_core import aidge_core
import numpy as np import numpy as np
from pathlib import Path
def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]: def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]:
""" """
...@@ -192,22 +193,31 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable ...@@ -192,22 +193,31 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable
:type params_file_format: str, optional :type params_file_format: str, optional
""" """
if json_path.is_dir():
json_path = (json_path.parent).joinpath('model.json')
elif not json_path.is_dir():
if json_path.suffix == '.json':
pass
else:
raise Exception('If ``json_path`` contains a filename it must be of JSON format.')
if write_trainable_params_ext: if write_trainable_params_ext:
dir_name, fname_ext = os.path.split(json_path) path_trainable_params = (json_path.parent).joinpath(json_path.stem + '_trainable_params/')
fname = os.path.splitext(fname_ext)[0] + '_trainable_params'
path_trainable_params = os.path.join(dir_name, fname)
os.mkdir(path_trainable_params)
else: else:
path_trainable_params = '' path_trainable_params = ''
# Sort GraphView in topological order if isinstance(gview, aidge_core.GraphView):
ordered_nodes = gview.get_ordered_nodes() # Sort GraphView in topological order
ordered_nodes = gview.get_ordered_nodes()
# Create dict from GraphView
graphview_dict = _create_dict(ordered_nodes, write_trainable_params_embed, write_trainable_params_ext, path_trainable_params, params_file_format)
# Write dict to JSON # Create dict from GraphView
_write_dict_json(graphview_dict, json_path) graphview_dict = _create_dict(ordered_nodes, write_trainable_params_embed, write_trainable_params_ext, path_trainable_params, params_file_format)
# Write dict to JSON
_write_dict_json(graphview_dict, json_path)
else:
raise Exception("Graph must be an aidge_core.GraphView instance.")
return None return None
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment