From 844fe303cd6914bd1becae64298d4a12cc128816 Mon Sep 17 00:00:00 2001
From: idealbuq <iryna.dealbuquerquesilva@cea.fr>
Date: Wed, 2 Oct 2024 07:48:26 +0000
Subject: [PATCH] Added checks for validity of GraphView and JSON path inputs

---
 aidge_core/show_graphview.py | 32 +++++++++++++++++++++-----------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py
index dfc0acc45..ee7908b12 100644
--- a/aidge_core/show_graphview.py
+++ b/aidge_core/show_graphview.py
@@ -3,6 +3,7 @@ import json
 import builtins
 import aidge_core
 import numpy as np
+from pathlib import Path
  
 def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]:
     """
@@ -192,22 +193,31 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable
     :type params_file_format: str, optional
     """
 
+    if json_path.is_dir():
+        json_path = (json_path.parent).joinpath('model.json')
+
+    elif not json_path.is_dir():
+        if json_path.suffix == '.json':
+            pass
+        else: 
+            raise Exception('If ``json_path`` contains a filename it must be of JSON format.')
 
     if write_trainable_params_ext:
-        dir_name, fname_ext = os.path.split(json_path)
-        fname = os.path.splitext(fname_ext)[0] + '_trainable_params'
-        path_trainable_params = os.path.join(dir_name, fname) 
-        os.mkdir(path_trainable_params)
+        path_trainable_params = (json_path.parent).joinpath(json_path.stem +  '_trainable_params/')
     else:
         path_trainable_params = ''
 
-    # Sort GraphView in topological order
-    ordered_nodes = gview.get_ordered_nodes()
-   
-    # Create dict from GraphView 
-    graphview_dict = _create_dict(ordered_nodes, write_trainable_params_embed, write_trainable_params_ext, path_trainable_params, params_file_format)
+    if isinstance(gview, aidge_core.GraphView):
+        # Sort GraphView in topological order
+        ordered_nodes = gview.get_ordered_nodes()
     
-    # Write dict to JSON
-    _write_dict_json(graphview_dict, json_path)
+        # Create dict from GraphView 
+        graphview_dict = _create_dict(ordered_nodes, write_trainable_params_embed, write_trainable_params_ext, path_trainable_params, params_file_format)
+        
+        # Write dict to JSON
+        _write_dict_json(graphview_dict, json_path)
+
+    else:
+        raise Exception("Graph must be an aidge_core.GraphView instance.")
         
     return None
\ No newline at end of file
-- 
GitLab