From cb4ce83ff1eff145085dab94537ec5006817aa53 Mon Sep 17 00:00:00 2001
From: idealbuq <iryna.dealbuquerquesilva@cea.fr>
Date: Tue, 1 Oct 2024 09:51:45 +0000
Subject: [PATCH] Added exception for non recognized file format for writing
 trainable parameters in external file

---
 aidge_core/show_graphview.py | 58 ++++++++++++++++++++----------------
 1 file changed, 33 insertions(+), 25 deletions(-)

diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py
index 94669110a..dfc0acc45 100644
--- a/aidge_core/show_graphview.py
+++ b/aidge_core/show_graphview.py
@@ -104,14 +104,14 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
             # Check if my node is a metaop
             attributes_dict = {}
             if isinstance(node.get_operator(), aidge_core.MetaOperator_Op):
-                    attributes_dict['micro_graph'] = []
-                    for micro_node in node.get_operator().get_micro_graph().get_nodes():
-                        micro_node_dict = {'name' : micro_node.name(), 
-                                           'optype' : micro_node.type()}
-                        
-                        micro_node_attr_dict =  _retrieve_operator_attrs(micro_node)
-                        micro_node_dict['attributes'] = micro_node_attr_dict
-                        attributes_dict['micro_graph'].append(micro_node_dict)
+                attributes_dict['micro_graph'] = []
+                for micro_node in node.get_operator().get_micro_graph().get_nodes():
+                    micro_node_dict = {'name' : micro_node.name(), 
+                                        'optype' : micro_node.type()}
+                    
+                    micro_node_attr_dict =  _retrieve_operator_attrs(micro_node)
+                    micro_node_dict['attributes'] = micro_node_attr_dict
+                    attributes_dict['micro_graph'].append(micro_node_dict)
 
             else:
                 node_attr_dict = _retrieve_operator_attrs(node)
@@ -120,26 +120,34 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
             node_dict['attributes'] = attributes_dict
 
             if node.type() == 'Producer':
-                if write_trainable_params_ext and params_file_format=='npz':
-                    np.savez_compressed(os.path.join(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)})
-                    node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.npz')
-
-                elif write_trainable_params_ext and params_file_format=='json':
-                    tensor = np.array(node.get_operator().get_output(0))
-                    tensor_dict = {
-                        node.name() : 
-                        {
-                            'dims' : tensor.shape,
-                            'data_type' : str(tensor.dtype),
-                            'tensor_data' : tensor.tolist()
-                        }   
-                    }
+                if write_trainable_params_ext:
+                    
+                    params_file_format.casefold()
+
+                    if params_file_format=='npz':
+                        np.savez_compressed(os.path.join(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)})
+                        node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.npz')
+
+                    elif params_file_format=='json':
+                        tensor = np.array(node.get_operator().get_output(0))
+                        tensor_dict = {
+                            node.name() : 
+                            {
+                                'dims' : tensor.shape,
+                                'data_type' : str(tensor.dtype),
+                                'tensor_data' : tensor.tolist()
+                            }   
+                        }
                                    
-                    with open(os.path.join(path_trainable_params, node.name() + '.json'), 'w') as fp:
-                        json.dump(tensor_dict, fp, indent=4)
+                        with open(os.path.join(path_trainable_params, node.name() + '.json'), 'w') as fp:
+                            json.dump(tensor_dict, fp, indent=4)
 
-                    node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.json')
+                        node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.json')
 
+                    else:
+                        raise Exception("File format to write trainable parameters not recognized.")
+
+                
                 elif write_trainable_params_embed:
                     node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist()
                 
-- 
GitLab