Skip to content
Snippets Groups Projects

Correct logic for Node's parents/children association with inputs/outputs in aidge_core/show_graphview.py.

Merged Iryna de Albuquerque Silva requested to merge idealbuq/aidge_core:dev into dev
1 file
+ 20
17
Compare changes
  • Side-by-side
  • Inline
+ 20
17
@@ -79,29 +79,32 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
if parents[0] is None: parents.append(parents.pop(0))
else:
pass
parents_inputs = []
for parent in parents:
input_idx = 0
for parent in node.get_parents():
if parent is not None:
for output_idx in range(parent.get_operator().nb_outputs()):
for input_idx in range(node.get_operator().nb_inputs()):
if parent.get_operator().get_output(output_idx).dims() == node.get_operator().get_input(input_idx).dims():
for children in parent.outputs():
for child in children:
if child[0] == node and child[1] == input_idx:
parents_inputs.append((parent.name(), input_idx))
elif parent is None:
for input_idx in list(range(node.get_operator().nb_inputs())):
if input_idx not in [item[1] for item in parents_inputs]:
parents_inputs.append((None, input_idx))
parents_inputs.sort(key=lambda x: x[1])
if input_idx not in [item[1] for item in parents_inputs]:
parents_inputs.append((None, input_idx))
input_idx += 1
node_dict['parents'] = parents_inputs
children_outputs = []
for child in node.get_children():
for input_idx in range(child.get_operator().nb_inputs()):
for output_idx in range(node.get_operator().nb_outputs()):
if child.get_operator().get_input(input_idx).dims() == node.get_operator().get_output(output_idx).dims():
children_outputs.append((child.name(), output_idx))
output_idx = 0
for children in node.get_ordered_children():
for child in children:
if child is not None:
for parent in child.inputs():
if parent[0] == node and parent[1] == output_idx:
children_outputs.append((child.name(), output_idx))
output_idx += 1
node_dict['children'] = children_outputs
# Check if my node is a metaop
@@ -151,7 +154,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
raise Exception("File format to write trainable parameters not recognized.")
elif write_trainable_params_embed:
if write_trainable_params_embed:
node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist()
else:
Loading