From cbc8270555e90e4c6440733caf430f48b141a4cb Mon Sep 17 00:00:00 2001
From: Iryna DE ALBUQUERQUE SILVA <iryna.dealbuquerquesilva@cea.fr>
Date: Thu, 7 Nov 2024 14:34:20 +0000
Subject: [PATCH] Correct logic for Node's parents/children association with
 inputs/outputs in aidge_core/show_graphview.py.

---
 aidge_core/show_graphview.py         | 59 ++++++++++++++++------------
 python_binding/graph/pybind_Node.cpp |  5 +++
 2 files changed, 38 insertions(+), 26 deletions(-)

diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py
index 633298f10..4f6a29603 100644
--- a/aidge_core/show_graphview.py
+++ b/aidge_core/show_graphview.py
@@ -79,29 +79,32 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
                 if parents[0] is None: parents.append(parents.pop(0))
             else:
                 pass
-
+           
             parents_inputs = []
-            for parent in parents:
+            input_idx = 0
+            for parent in node.get_parents():
                 if parent is not None:
-                    for output_idx in range(parent.get_operator().nb_outputs()):
-                        for input_idx in range(node.get_operator().nb_inputs()):
-                            if parent.get_operator().get_output(output_idx).dims() == node.get_operator().get_input(input_idx).dims():
+                    for children in parent.outputs():
+                        for child in children:
+                            if child[0] == node and child[1] == input_idx:
                                 parents_inputs.append((parent.name(), input_idx))
-
+                
                 elif parent is None:
-                    for input_idx in list(range(node.get_operator().nb_inputs())):
-                        if input_idx not in [item[1] for item in parents_inputs]:
-                                parents_inputs.append((None, input_idx))
-
-            parents_inputs.sort(key=lambda x: x[1])
+                    if input_idx not in [item[1] for item in parents_inputs]:
+                        parents_inputs.append((None, input_idx))
+                
+                input_idx += 1
             node_dict['parents'] = parents_inputs
 
             children_outputs = []
-            for child in node.get_children():
-                for input_idx in range(child.get_operator().nb_inputs()):
-                    for output_idx in range(node.get_operator().nb_outputs()):
-                        if child.get_operator().get_input(input_idx).dims() == node.get_operator().get_output(output_idx).dims():
-                            children_outputs.append((child.name(), output_idx))
+            output_idx = 0
+            for children in node.get_ordered_children():
+                for child in children:
+                    if child is not None:
+                        for parent in child.inputs():
+                            if parent[0] == node and parent[1] == output_idx:
+                                children_outputs.append((child.name(), output_idx))
+                output_idx += 1
             node_dict['children'] = children_outputs
 
             # Check if my node is a metaop
@@ -129,7 +132,7 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
 
                     if params_file_format=='npz':
                         np.savez_compressed(Path(path_trainable_params, node.name()), **{node.name() : node.get_operator().get_output(0)})
-                        node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.npz')
+                        node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.npz'))
 
                     elif params_file_format=='json':
                         tensor = np.array(node.get_operator().get_output(0))
@@ -145,13 +148,13 @@ def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_e
                         with open(Path(path_trainable_params, node.name() + '.json'), 'w') as fp:
                             json.dump(tensor_dict, fp, indent=4)
 
-                        node_dict['tensor_data'] = Path(path_trainable_params, node.name() + '.json')
+                        node_dict['tensor_data'] = str(Path(path_trainable_params, node.name() + '.json'))
 
                     else:
                         raise Exception("File format to write trainable parameters not recognized.")
 
 
-                elif write_trainable_params_embed:
+                if write_trainable_params_embed:
                     node_dict['tensor_data'] = np.array(node.get_operator().get_output(0)).tolist()
 
                 else:
@@ -195,17 +198,21 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : Path, write_trainabl
     :type params_file_format: str, optional
     """
 
-    if json_path.is_dir():
-        json_path = (json_path.parent).joinpath('model.json')
+    if not json_path.suffix:
+        if not json_path.is_dir():
+            json_path.mkdir(parents=True, exist_ok=True)
+        json_path = json_path.joinpath('model.json')
 
-    elif not json_path.is_dir():
-        if json_path.suffix == '.json':
-            pass
-        else:
-            raise Exception('If ``json_path`` contains a filename it must be of JSON format.')
+    else:
+        if json_path.suffix != '.json':
+            raise Exception('If ``json_path`` contains a filename, it must be of JSON format.')
+        if not json_path.parent.is_dir():
+            json_path.parent.mkdir(parents=True, exist_ok=True)
 
     if write_trainable_params_ext:
         path_trainable_params = (json_path.parent).joinpath(json_path.stem +  '_trainable_params/')
+        path_trainable_params.mkdir(parents=True, exist_ok=True)
+
     else:
         path_trainable_params = Path()
 
diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp
index 35f632744..69a28960b 100644
--- a/python_binding/graph/pybind_Node.cpp
+++ b/python_binding/graph/pybind_Node.cpp
@@ -176,6 +176,11 @@ void init_Node(py::module& m) {
     Get children.
     )mydelimiter")
 
+    .def("get_ordered_children", &Node::getOrderedChildren,
+    R"mydelimiter(
+    Get ordered children.
+    )mydelimiter")
+
     .def("__call__",
         [](Node &self, pybind11::args args) {
             std::vector<Connector> connectors;
-- 
GitLab