From cb4ce83ff1eff145085dab94537ec5006817aa53 Mon Sep 17 00:00:00 2001 From: idealbuq <iryna.dealbuquerquesilva@cea.fr> Date: Tue, 1 Oct 2024 09:51:45 +0000 Subject: [PATCH] Added exception for non recognized file format for writing trainable parameters in external file --- aidge_core/show_graphview.py | 58 ++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py index 94669110a..dfc0acc45 100644 --- a/aidge_core/show_graphview.py +++ b/aidge_core/show_graphview.py @@ -104,14 +104,14 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e # Check if my node is a metaop attributes_dict = {} if isinstance(node.get_operator(), aidge_core.MetaOperator_Op): - attributes_dict['micro_graph'] = [] - for micro_node in node.get_operator().get_micro_graph().get_nodes(): - micro_node_dict = {'name' : micro_node.name(), - 'optype' : micro_node.type()} - - micro_node_attr_dict = _retrieve_operator_attrs(micro_node) - micro_node_dict['attributes'] = micro_node_attr_dict - attributes_dict['micro_graph'].append(micro_node_dict) + attributes_dict['micro_graph'] = [] + for micro_node in node.get_operator().get_micro_graph().get_nodes(): + micro_node_dict = {'name' : micro_node.name(), + 'optype' : micro_node.type()} + + micro_node_attr_dict = _retrieve_operator_attrs(micro_node) + micro_node_dict['attributes'] = micro_node_attr_dict + attributes_dict['micro_graph'].append(micro_node_dict) else: node_attr_dict = _retrieve_operator_attrs(node) @@ -120,26 +120,34 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e node_dict['attributes'] = attributes_dict if node.type() == 'Producer': - if write_trainable_params_ext and params_file_format=='npz': - np.savez_compressed(os.path.join(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)}) - node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.npz') - - elif write_trainable_params_ext and params_file_format=='json': - tensor = np.array(node.get_operator().get_output(0)) - tensor_dict = { - node.name() : - { - 'dims' : tensor.shape, - 'data_type' : str(tensor.dtype), - 'tensor_data' : tensor.tolist() - } - } + if write_trainable_params_ext: + + params_file_format.casefold() + + if params_file_format=='npz': + np.savez_compressed(os.path.join(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)}) + node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.npz') + + elif params_file_format=='json': + tensor = np.array(node.get_operator().get_output(0)) + tensor_dict = { + node.name() : + { + 'dims' : tensor.shape, + 'data_type' : str(tensor.dtype), + 'tensor_data' : tensor.tolist() + } + } - with open(os.path.join(path_trainable_params, node.name() + '.json'), 'w') as fp: - json.dump(tensor_dict, fp, indent=4) + with open(os.path.join(path_trainable_params, node.name() + '.json'), 'w') as fp: + json.dump(tensor_dict, fp, indent=4) - node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.json') + node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.json') + else: + raise Exception("File format to write trainable parameters not recognized.") + + elif write_trainable_params_embed: node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist() -- GitLab