diff --git a/aidge_core/show_graphview.py b/aidge_core/show_graphview.py index 9fcde55732c3c21c922d0e03f627c4ef66df8ab2..3d00d576ae1f4b4bbd1bd1b4bb966eed3c2a9764 100644 --- a/aidge_core/show_graphview.py +++ b/aidge_core/show_graphview.py @@ -4,73 +4,11 @@ import builtins import aidge_core import numpy as np -def dfs(graph : aidge_core.GraphView, node : aidge_core.Node, visited_nodes : list[aidge_core.Node], sorted_nodes : list[aidge_core.Node]) -> None: - """ - Performs the depth-first search algorithm for topological sorting. - - :param graph: An unsorted GraphView of Aidge. - :type graph: aidge_core.GraphView - :param node: The GraphView's Node that is being treated. - :type node: aidge_core.Node - :param visited_nodes: List of nodes that have already been visited. - :type visited_nodes: list[aidge_core.Node] - :param sorted_nodes: List of nodes that have already been sorted. - :type sorted_nodes: list[aidge_core.Node] - """ - node_children = node.get_children() - - visited_nodes.add(node) - - for child in node_children: - if child not in visited_nodes: - dfs(graph, child, visited_nodes, sorted_nodes) - - sorted_nodes.append(node) - - # Make sure Producers are treated: - parents = [] - for parent in node.get_parents(): - try: - has_parents = parent.get_parents() - - except AttributeError: - has_parents = False - - if (not has_parents) and (parent not in sorted_nodes): - parents.append(parent) - - parents.reverse() - sorted_nodes.extend(parents) - - return None - -def topological_sort(graph : aidge_core.GraphView) -> list[aidge_core.Node]: - """ - Performs topological sorting by applying depth-first search algorithm recursively. - - :param graph: An unsorted GraphView of Aidge. - :type graph: aidge_core.GraphView - :return: A list with the GraphView's sorted nodes. - :rtype: list[aidge_core.Node] - """ - - input_nodes = graph.get_input_nodes() - visited_nodes = set() - sorted_nodes = [] - - for input in input_nodes: - if input not in visited_nodes: - dfs(input_nodes, input, visited_nodes, sorted_nodes) - - sorted_nodes.reverse() - - return sorted_nodes - -def retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]: +def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]: """ Returns the dictionary containing the attributes of a given Node. - :param graph: A Node in the list of sorted nodes. + :param graph: A Node in the list of ordered nodes. :type graph: aidge_core.Node :return: A dictionary with the Node's attributes. :rtype: dict[str, int, float, bool, None] @@ -86,11 +24,11 @@ def retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, boo return node_attr_dict -def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext : bool, write_trainable_params_embed : bool, params_file_format : str, path_trainable_params : str) -> dict[str, int, float, bool, None]: +def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_ext : bool, write_trainable_params_embed : bool, params_file_format : str, path_trainable_params : str) -> dict[str, int, float, bool, None]: """ - Creates a dictionary to store the information of a given sorted GraphView. + Creates a dictionary to store the information of a given ordered GraphView. - :param sorted_nodes: A list with the GraphView's sorted nodes. + :param ordered_nodes: A list with the GraphView's ordered nodes. :type graph: list :param write_trainable_params_ext: Whether or not to write the eventual trainable parameters of the Nodes in an external file. :type write_trainable_params_ext: bool @@ -107,7 +45,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext graphview_dict = {'graph': []} - for node in sorted_nodes: + for node in ordered_nodes: if node is not None: node_dict = {'name' : node.name(), @@ -171,12 +109,12 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext micro_node_dict = {'name' : micro_node.name(), 'optype' : micro_node.type()} - micro_node_attr_dict = retrieve_operator_attrs(micro_node) + micro_node_attr_dict = _retrieve_operator_attrs(micro_node) micro_node_dict['attributes'] = micro_node_attr_dict attributes_dict['micro_graph'].append(micro_node_dict) else: - node_attr_dict = retrieve_operator_attrs(node) + node_attr_dict = _retrieve_operator_attrs(node) attributes_dict.update(node_attr_dict) node_dict['attributes'] = attributes_dict @@ -215,7 +153,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext return graphview_dict -def write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None: +def _write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None: """ Writes dictionary containing GraphView description to a JSON file. @@ -255,13 +193,13 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable else: path_trainable_params = '' - # Sort graphview - sorted_nodes = topological_sort(gview) + # Sort GraphView in topological order + ordered_nodes = gview.get_ordered_nodes() - # Create dict from graphview - graphview_dict = create_dict(sorted_nodes, write_trainable_params_ext, write_trainable_params_embed, params_file_format, path_trainable_params) + # Create dict from GraphView + graphview_dict = _create_dict(ordered_nodes, write_trainable_params_ext, write_trainable_params_embed, params_file_format, path_trainable_params) - # Write dict to json - write_dict_json(graphview_dict, json_path) + # Write dict to JSON + _write_dict_json(graphview_dict, json_path) return None \ No newline at end of file diff --git a/aidge_core/unit_tests/test_show_graphview.py b/aidge_core/unit_tests/test_show_graphview.py index ccdc97febb879219b00a89603d9f0aa6b48295ff..58547301230f589723615e88f35358f470b536f5 100644 --- a/aidge_core/unit_tests/test_show_graphview.py +++ b/aidge_core/unit_tests/test_show_graphview.py @@ -50,7 +50,7 @@ def create_gview(): return gview class test_show_gview(unittest.TestCase): - """Test aidge show GraphView + """Test aidge functionality to show GraphView. """ def setUp(self): @@ -59,9 +59,8 @@ class test_show_gview(unittest.TestCase): def tearDown(self): pass - def test_model_to_json(self): + def test_gview_to_json(self): - gview = create_gview() # Create temporary file to store JSON model description @@ -74,58 +73,55 @@ class test_show_gview(unittest.TestCase): model_json = json.load(fp) # Get list of nodes of Aidge graphview - gview_ranked_nodes = gview.get_ranked_nodes()[0] - - # Iterate over ranked_nodes + gview_ordered_nodes = gview.get_ordered_nodes() - self.assertEqual(len(gview_ranked_nodes), len(model_json['graph'])) + # Iterate over the list of ordered nodes and the corresponding JSON + self.assertEqual(len(gview_ordered_nodes), len(model_json['graph'])) - for node_gview in gview_ranked_nodes: - for node_json in model_json['graph']: - if node_gview.name() == node_json['name']: + for node_gview, node_json in zip(gview_ordered_nodes, model_json['graph']): - self.assertEqual(node_gview.get_operator().type(), node_json['optype']) - self.assertEqual(node_gview.get_operator().nb_inputs(), node_json['nb_inputs']) - self.assertEqual(node_gview.get_operator().nb_outputs(), node_json['nb_outputs']) - - self.assertEqual(node_gview.get_operator().nb_inputs(), len(node_json['inputs'])) - for input_idx in range(node_gview.get_operator().nb_inputs()): - self.assertEqual(node_gview.get_operator().get_input(input_idx).dims(), node_json['inputs'][input_idx]['dims']) - self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dtype()), node_json['inputs'][input_idx]['data_type']) - self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dformat()), node_json['inputs'][input_idx]['data_format']) - - self.assertEqual(node_gview.get_operator().nb_outputs(), len(node_json['outputs'])) - for output_idx in range(node_gview.get_operator().nb_outputs()): - self.assertEqual(node_gview.get_operator().get_output(output_idx).dims(), node_json['outputs'][output_idx]['dims']) - self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dtype()), node_json['outputs'][output_idx]['data_type']) - self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dformat()), node_json['outputs'][output_idx]['data_format']) - - self.assertEqual(len(node_gview.get_parents()), len(node_json['parents'])) - self.assertEqual(len(node_gview.get_children()), len(node_json['children'])) - - if not hasattr(node_gview.get_operator(), 'get_micro_graph'): - try: - self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes'])) - self.assertDictEqual(node_gview.get_operator().attr.dict(), node_json['attributes']) - - except AttributeError: - self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes']) - - elif hasattr(node_gview.get_operator(), 'get_micro_graph'): - - self.assertEqual(len(node_gview.get_operator().get_micro_graph().get_nodes()), len(node_json['attributes']['micro_graph'])) - - for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes(): - for micro_node_json in node_json['attributes']['micro_graph']: - if micro_node_gview.get_operator().type() == micro_node_json['optype']: - - for key, value in micro_node_gview.get_operator().attr.dict().items(): - if not type(value).__name__ in dir(builtins): - # Replace original value by its name (str) because value is of a type that could not be written to the JSON - # Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name}) - temp_mnode_dict = micro_node_gview.get_operator().attr.dict() - temp_mnode_dict.update({key : value.name}) - self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes']) + self.assertEqual(node_gview.get_operator().type(), node_json['optype']) + self.assertEqual(node_gview.get_operator().nb_inputs(), node_json['nb_inputs']) + self.assertEqual(node_gview.get_operator().nb_outputs(), node_json['nb_outputs']) + + self.assertEqual(node_gview.get_operator().nb_inputs(), len(node_json['inputs'])) + for input_idx in range(node_gview.get_operator().nb_inputs()): + self.assertEqual(node_gview.get_operator().get_input(input_idx).dims(), node_json['inputs'][input_idx]['dims']) + self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dtype()), node_json['inputs'][input_idx]['data_type']) + self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dformat()), node_json['inputs'][input_idx]['data_format']) + + self.assertEqual(node_gview.get_operator().nb_outputs(), len(node_json['outputs'])) + for output_idx in range(node_gview.get_operator().nb_outputs()): + self.assertEqual(node_gview.get_operator().get_output(output_idx).dims(), node_json['outputs'][output_idx]['dims']) + self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dtype()), node_json['outputs'][output_idx]['data_type']) + self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dformat()), node_json['outputs'][output_idx]['data_format']) + + self.assertEqual(len(node_gview.get_parents()), len(node_json['parents'])) + self.assertEqual(len(node_gview.get_children()), len(node_json['children'])) + + if not hasattr(node_gview.get_operator(), 'get_micro_graph'): + try: + self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes'])) + self.assertDictEqual(node_gview.get_operator().attr.dict(), node_json['attributes']) + + except AttributeError: + self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes']) + + elif hasattr(node_gview.get_operator(), 'get_micro_graph'): + + self.assertEqual(len(node_gview.get_operator().get_micro_graph().get_nodes()), len(node_json['attributes']['micro_graph'])) + + for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes(): + for micro_node_json in node_json['attributes']['micro_graph']: + if micro_node_gview.get_operator().type() == micro_node_json['optype']: + + for key, value in micro_node_gview.get_operator().attr.dict().items(): + if not type(value).__name__ in dir(builtins): + # Replace original value by its name (str) because value is of a type that could not be written to the JSON + # Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name}) + temp_mnode_dict = micro_node_gview.get_operator().attr.dict() + temp_mnode_dict.update({key : value.name}) + self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes']) if __name__ == '__main__': unittest.main()