Skip to content
Snippets Groups Projects
Commit cb4ce83f authored by Iryna DE ALBUQUERQUE SILVA's avatar Iryna DE ALBUQUERQUE SILVA
Browse files

Added exception for non recognized file format for writing trainable parameters in external file

parent 9d563182
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!211Add show_graphview funcionality.
......@@ -104,14 +104,14 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
# Check if my node is a metaop
attributes_dict = {}
if isinstance(node.get_operator(), aidge_core.MetaOperator_Op):
attributes_dict['micro_graph'] = []
for micro_node in node.get_operator().get_micro_graph().get_nodes():
micro_node_dict = {'name' : micro_node.name(),
'optype' : micro_node.type()}
micro_node_attr_dict = _retrieve_operator_attrs(micro_node)
micro_node_dict['attributes'] = micro_node_attr_dict
attributes_dict['micro_graph'].append(micro_node_dict)
attributes_dict['micro_graph'] = []
for micro_node in node.get_operator().get_micro_graph().get_nodes():
micro_node_dict = {'name' : micro_node.name(),
'optype' : micro_node.type()}
micro_node_attr_dict = _retrieve_operator_attrs(micro_node)
micro_node_dict['attributes'] = micro_node_attr_dict
attributes_dict['micro_graph'].append(micro_node_dict)
else:
node_attr_dict = _retrieve_operator_attrs(node)
......@@ -120,26 +120,34 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
node_dict['attributes'] = attributes_dict
if node.type() == 'Producer':
if write_trainable_params_ext and params_file_format=='npz':
np.savez_compressed(os.path.join(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)})
node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.npz')
elif write_trainable_params_ext and params_file_format=='json':
tensor = np.array(node.get_operator().get_output(0))
tensor_dict = {
node.name() :
{
'dims' : tensor.shape,
'data_type' : str(tensor.dtype),
'tensor_data' : tensor.tolist()
}
}
if write_trainable_params_ext:
params_file_format.casefold()
if params_file_format=='npz':
np.savez_compressed(os.path.join(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)})
node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.npz')
elif params_file_format=='json':
tensor = np.array(node.get_operator().get_output(0))
tensor_dict = {
node.name() :
{
'dims' : tensor.shape,
'data_type' : str(tensor.dtype),
'tensor_data' : tensor.tolist()
}
}
with open(os.path.join(path_trainable_params, node.name() + '.json'), 'w') as fp:
json.dump(tensor_dict, fp, indent=4)
with open(os.path.join(path_trainable_params, node.name() + '.json'), 'w') as fp:
json.dump(tensor_dict, fp, indent=4)
node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.json')
node_dict['tensor_data'] = os.path.join(path_trainable_params, node.name() + '.json')
else:
raise Exception("File format to write trainable parameters not recognized.")
elif write_trainable_params_embed:
node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment