From cbc8270555e90e4c6440733caf430f48b141a4cb Mon Sep 17 00:00:00 2001 From: Iryna DE ALBUQUERQUE SILVA <iryna.dealbuquerquesilva@cea.fr> Date: Thu, 7 Nov 2024 14:34:20 +0000 Subject: [PATCH] Correct logic for Node's parents/children association with inputs/outputs in aidge_core/show_graphview.py. --- aidge_core/show_graphview.py | 59 ++++++++++++++++------------ python_binding/graph/pybind_Node.cpp | 5 +++ 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py index 633298f10..4f6a29603 100644 --- a/aidge_core/show_graphview.py +++ b/aidge_core/show_graphview.py @@ -79,29 +79,32 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e if parents[0] is None: parents.append(parents.pop(0)) else: pass - + parents_inputs = [] - for parent in parents: + input_idx = 0 + for parent in node.get_parents(): if parent is not None: - for output_idx in range(parent.get_operator().nb_outputs()): - for input_idx in range(node.get_operator().nb_inputs()): - if parent.get_operator().get_output(output_idx).dims() == node.get_operator().get_input(input_idx).dims(): + for children in parent.outputs(): + for child in children: + if child[0] == node and child[1] == input_idx: parents_inputs.append((parent.name(), input_idx)) - + elif parent is None: - for input_idx in list(range(node.get_operator().nb_inputs())): - if input_idx not in [item[1] for item in parents_inputs]: - parents_inputs.append((None, input_idx)) - - parents_inputs.sort(key=lambda x: x[1]) + if input_idx not in [item[1] for item in parents_inputs]: + parents_inputs.append((None, input_idx)) + + input_idx += 1 node_dict['parents'] = parents_inputs children_outputs = [] - for child in node.get_children(): - for input_idx in range(child.get_operator().nb_inputs()): - for output_idx in range(node.get_operator().nb_outputs()): - if child.get_operator().get_input(input_idx).dims() == node.get_operator().get_output(output_idx).dims(): - children_outputs.append((child.name(), output_idx)) + output_idx = 0 + for children in node.get_ordered_children(): + for child in children: + if child is not None: + for parent in child.inputs(): + if parent[0] == node and parent[1] == output_idx: + children_outputs.append((child.name(), output_idx)) + output_idx += 1 node_dict['children'] = children_outputs # Check if my node is a metaop @@ -129,7 +132,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e if params_file_format=='npz': np.savez_compressed(Path(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)}) - node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.npz') + node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.npz')) elif params_file_format=='json': tensor = np.array(node.get_operator().get_output(0)) @@ -145,13 +148,13 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e with open(Path(path_trainable_params, node.name() + '.json'), 'w') as fp: json.dump(tensor_dict, fp, indent=4) - node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.json') + node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.json')) else: raise Exception("File format to write trainable parameters not recognized.") - elif write_trainable_params_embed: + if write_trainable_params_embed: node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist() else: @@ -195,17 +198,21 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : Path, write_trainabl :type params_file_format: str, optional """ - if json_path.is_dir(): - json_path = (json_path.parent).joinpath('model.json') + if not json_path.suffix: + if not json_path.is_dir(): + json_path.mkdir(parents=True, exist_ok=True) + json_path = json_path.joinpath('model.json') - elif not json_path.is_dir(): - if json_path.suffix == '.json': - pass - else: - raise Exception('If ``json_path`` contains a filename it must be of JSON format.') + else: + if json_path.suffix != '.json': + raise Exception('If ``json_path`` contains a filename, it must be of JSON format.') + if not json_path.parent.is_dir(): + json_path.parent.mkdir(parents=True, exist_ok=True) if write_trainable_params_ext: path_trainable_params = (json_path.parent).joinpath(json_path.stem + '_trainable_params/') + path_trainable_params.mkdir(parents=True, exist_ok=True) + else: path_trainable_params = Path() diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index 35f632744..69a28960b 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -176,6 +176,11 @@ void init_Node(py::module& m) { Get children. )mydelimiter") + .def("get_ordered_children", &Node::getOrderedChildren, + R"mydelimiter( + Get ordered children. + )mydelimiter") + .def("__call__", [](Node &self, pybind11::args args) { std::vector<Connector> connectors; -- GitLab