Skip to content
Snippets Groups Projects
Commit cbc82705 authored by Iryna DE ALBUQUERQUE SILVA's avatar Iryna DE ALBUQUERQUE SILVA Committed by Cyril Moineau
Browse files

Correct logic for Node's parents/children association with inputs/outputs in...

Correct logic for Node's parents/children association with inputs/outputs in aidge_core/show_graphview.py.
parent f4437b4d
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!243Correct logic for Node's parents/children association with inputs/outputs in aidge_core/show_graphview.py.
......@@ -79,29 +79,32 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
if parents[0] is None: parents.append(parents.pop(0))
else:
pass
parents_inputs = []
for parent in parents:
input_idx = 0
for parent in node.get_parents():
if parent is not None:
for output_idx in range(parent.get_operator().nb_outputs()):
for input_idx in range(node.get_operator().nb_inputs()):
if parent.get_operator().get_output(output_idx).dims() == node.get_operator().get_input(input_idx).dims():
for children in parent.outputs():
for child in children:
if child[0] == node and child[1] == input_idx:
parents_inputs.append((parent.name(), input_idx))
elif parent is None:
for input_idx in list(range(node.get_operator().nb_inputs())):
if input_idx not in [item[1] for item in parents_inputs]:
parents_inputs.append((None, input_idx))
parents_inputs.sort(key=lambda x: x[1])
if input_idx not in [item[1] for item in parents_inputs]:
parents_inputs.append((None, input_idx))
input_idx += 1
node_dict['parents'] = parents_inputs
children_outputs = []
for child in node.get_children():
for input_idx in range(child.get_operator().nb_inputs()):
for output_idx in range(node.get_operator().nb_outputs()):
if child.get_operator().get_input(input_idx).dims() == node.get_operator().get_output(output_idx).dims():
children_outputs.append((child.name(), output_idx))
output_idx = 0
for children in node.get_ordered_children():
for child in children:
if child is not None:
for parent in child.inputs():
if parent[0] == node and parent[1] == output_idx:
children_outputs.append((child.name(), output_idx))
output_idx += 1
node_dict['children'] = children_outputs
# Check if my node is a metaop
......@@ -129,7 +132,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
if params_file_format=='npz':
np.savez_compressed(Path(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)})
node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.npz')
node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.npz'))
elif params_file_format=='json':
tensor = np.array(node.get_operator().get_output(0))
......@@ -145,13 +148,13 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
with open(Path(path_trainable_params, node.name() + '.json'), 'w') as fp:
json.dump(tensor_dict, fp, indent=4)
node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.json')
node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.json'))
else:
raise Exception("File format to write trainable parameters not recognized.")
elif write_trainable_params_embed:
if write_trainable_params_embed:
node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist()
else:
......@@ -195,17 +198,21 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : Path, write_trainabl
:type params_file_format: str, optional
"""
if json_path.is_dir():
json_path = (json_path.parent).joinpath('model.json')
if not json_path.suffix:
if not json_path.is_dir():
json_path.mkdir(parents=True, exist_ok=True)
json_path = json_path.joinpath('model.json')
elif not json_path.is_dir():
if json_path.suffix == '.json':
pass
else:
raise Exception('If ``json_path`` contains a filename it must be of JSON format.')
else:
if json_path.suffix != '.json':
raise Exception('If ``json_path`` contains a filename, it must be of JSON format.')
if not json_path.parent.is_dir():
json_path.parent.mkdir(parents=True, exist_ok=True)
if write_trainable_params_ext:
path_trainable_params = (json_path.parent).joinpath(json_path.stem + '_trainable_params/')
path_trainable_params.mkdir(parents=True, exist_ok=True)
else:
path_trainable_params = Path()
......
......@@ -176,6 +176,11 @@ void init_Node(py::module& m) {
Get children.
)mydelimiter")
.def("get_ordered_children", &Node::getOrderedChildren,
R"mydelimiter(
Get ordered children.
)mydelimiter")
.def("__call__",
[](Node &self, pybind11::args args) {
std::vector<Connector> connectors;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment