Skip to content
Snippets Groups Projects
Commit 19398e4a authored by Iryna de Albuquerque Silva's avatar Iryna de Albuquerque Silva
Browse files

Replaced local function for topological ordering by new GraphView's get_ordered_nodes() method

parent c6654d1e
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!211Add show_graphview funcionality.
...@@ -4,73 +4,11 @@ import builtins ...@@ -4,73 +4,11 @@ import builtins
import aidge_core import aidge_core
import numpy as np import numpy as np
def dfs(graph : aidge_core.GraphView, node : aidge_core.Node, visited_nodes : list[aidge_core.Node], sorted_nodes : list[aidge_core.Node]) -> None: def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]:
"""
Performs the depth-first search algorithm for topological sorting.
:param graph: An unsorted GraphView of Aidge.
:type graph: aidge_core.GraphView
:param node: The GraphView's Node that is being treated.
:type node: aidge_core.Node
:param visited_nodes: List of nodes that have already been visited.
:type visited_nodes: list[aidge_core.Node]
:param sorted_nodes: List of nodes that have already been sorted.
:type sorted_nodes: list[aidge_core.Node]
"""
node_children = node.get_children()
visited_nodes.add(node)
for child in node_children:
if child not in visited_nodes:
dfs(graph, child, visited_nodes, sorted_nodes)
sorted_nodes.append(node)
# Make sure Producers are treated:
parents = []
for parent in node.get_parents():
try:
has_parents = parent.get_parents()
except AttributeError:
has_parents = False
if (not has_parents) and (parent not in sorted_nodes):
parents.append(parent)
parents.reverse()
sorted_nodes.extend(parents)
return None
def topological_sort(graph : aidge_core.GraphView) -> list[aidge_core.Node]:
"""
Performs topological sorting by applying depth-first search algorithm recursively.
:param graph: An unsorted GraphView of Aidge.
:type graph: aidge_core.GraphView
:return: A list with the GraphView's sorted nodes.
:rtype: list[aidge_core.Node]
"""
input_nodes = graph.get_input_nodes()
visited_nodes = set()
sorted_nodes = []
for input in input_nodes:
if input not in visited_nodes:
dfs(input_nodes, input, visited_nodes, sorted_nodes)
sorted_nodes.reverse()
return sorted_nodes
def retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]:
""" """
Returns the dictionary containing the attributes of a given Node. Returns the dictionary containing the attributes of a given Node.
:param graph: A Node in the list of sorted nodes. :param graph: A Node in the list of ordered nodes.
:type graph: aidge_core.Node :type graph: aidge_core.Node
:return: A dictionary with the Node's attributes. :return: A dictionary with the Node's attributes.
:rtype: dict[str, int, float, bool, None] :rtype: dict[str, int, float, bool, None]
...@@ -86,11 +24,11 @@ def retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, boo ...@@ -86,11 +24,11 @@ def retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, boo
return node_attr_dict return node_attr_dict
def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext : bool, write_trainable_params_embed : bool, params_file_format : str, path_trainable_params : str) -> dict[str, int, float, bool, None]: def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_ext : bool, write_trainable_params_embed : bool, params_file_format : str, path_trainable_params : str) -> dict[str, int, float, bool, None]:
""" """
Creates a dictionary to store the information of a given sorted GraphView. Creates a dictionary to store the information of a given ordered GraphView.
:param sorted_nodes: A list with the GraphView's sorted nodes. :param ordered_nodes: A list with the GraphView's ordered nodes.
:type graph: list :type graph: list
:param write_trainable_params_ext: Whether or not to write the eventual trainable parameters of the Nodes in an external file. :param write_trainable_params_ext: Whether or not to write the eventual trainable parameters of the Nodes in an external file.
:type write_trainable_params_ext: bool :type write_trainable_params_ext: bool
...@@ -107,7 +45,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext ...@@ -107,7 +45,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext
graphview_dict = {'graph': []} graphview_dict = {'graph': []}
for node in sorted_nodes: for node in ordered_nodes:
if node is not None: if node is not None:
node_dict = {'name' : node.name(), node_dict = {'name' : node.name(),
...@@ -171,12 +109,12 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext ...@@ -171,12 +109,12 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext
micro_node_dict = {'name' : micro_node.name(), micro_node_dict = {'name' : micro_node.name(),
'optype' : micro_node.type()} 'optype' : micro_node.type()}
micro_node_attr_dict = retrieve_operator_attrs(micro_node) micro_node_attr_dict = _retrieve_operator_attrs(micro_node)
micro_node_dict['attributes'] = micro_node_attr_dict micro_node_dict['attributes'] = micro_node_attr_dict
attributes_dict['micro_graph'].append(micro_node_dict) attributes_dict['micro_graph'].append(micro_node_dict)
else: else:
node_attr_dict = retrieve_operator_attrs(node) node_attr_dict = _retrieve_operator_attrs(node)
attributes_dict.update(node_attr_dict) attributes_dict.update(node_attr_dict)
node_dict['attributes'] = attributes_dict node_dict['attributes'] = attributes_dict
...@@ -215,7 +153,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext ...@@ -215,7 +153,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext
return graphview_dict return graphview_dict
def write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None: def _write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None:
""" """
Writes dictionary containing GraphView description to a JSON file. Writes dictionary containing GraphView description to a JSON file.
...@@ -255,13 +193,13 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable ...@@ -255,13 +193,13 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable
else: else:
path_trainable_params = '' path_trainable_params = ''
# Sort graphview # Sort GraphView in topological order
sorted_nodes = topological_sort(gview) ordered_nodes = gview.get_ordered_nodes()
# Create dict from graphview # Create dict from GraphView
graphview_dict = create_dict(sorted_nodes, write_trainable_params_ext, write_trainable_params_embed, params_file_format, path_trainable_params) graphview_dict = _create_dict(ordered_nodes, write_trainable_params_ext, write_trainable_params_embed, params_file_format, path_trainable_params)
# Write dict to json # Write dict to JSON
write_dict_json(graphview_dict, json_path) _write_dict_json(graphview_dict, json_path)
return None return None
\ No newline at end of file
...@@ -50,7 +50,7 @@ def create_gview(): ...@@ -50,7 +50,7 @@ def create_gview():
return gview return gview
class test_show_gview(unittest.TestCase): class test_show_gview(unittest.TestCase):
"""Test aidge show GraphView """Test aidge functionality to show GraphView.
""" """
def setUp(self): def setUp(self):
...@@ -59,9 +59,8 @@ class test_show_gview(unittest.TestCase): ...@@ -59,9 +59,8 @@ class test_show_gview(unittest.TestCase):
def tearDown(self): def tearDown(self):
pass pass
def test_model_to_json(self): def test_gview_to_json(self):
gview = create_gview() gview = create_gview()
# Create temporary file to store JSON model description # Create temporary file to store JSON model description
...@@ -74,58 +73,55 @@ class test_show_gview(unittest.TestCase): ...@@ -74,58 +73,55 @@ class test_show_gview(unittest.TestCase):
model_json = json.load(fp) model_json = json.load(fp)
# Get list of nodes of Aidge graphview # Get list of nodes of Aidge graphview
gview_ranked_nodes = gview.get_ranked_nodes()[0] gview_ordered_nodes = gview.get_ordered_nodes()
# Iterate over ranked_nodes
self.assertEqual(len(gview_ranked_nodes), len(model_json['graph'])) # Iterate over the list of ordered nodes and the corresponding JSON
self.assertEqual(len(gview_ordered_nodes), len(model_json['graph']))
for node_gview in gview_ranked_nodes: for node_gview, node_json in zip(gview_ordered_nodes, model_json['graph']):
for node_json in model_json['graph']:
if node_gview.name() == node_json['name']:
self.assertEqual(node_gview.get_operator().type(), node_json['optype']) self.assertEqual(node_gview.get_operator().type(), node_json['optype'])
self.assertEqual(node_gview.get_operator().nb_inputs(), node_json['nb_inputs']) self.assertEqual(node_gview.get_operator().nb_inputs(), node_json['nb_inputs'])
self.assertEqual(node_gview.get_operator().nb_outputs(), node_json['nb_outputs']) self.assertEqual(node_gview.get_operator().nb_outputs(), node_json['nb_outputs'])
self.assertEqual(node_gview.get_operator().nb_inputs(), len(node_json['inputs'])) self.assertEqual(node_gview.get_operator().nb_inputs(), len(node_json['inputs']))
for input_idx in range(node_gview.get_operator().nb_inputs()): for input_idx in range(node_gview.get_operator().nb_inputs()):
self.assertEqual(node_gview.get_operator().get_input(input_idx).dims(), node_json['inputs'][input_idx]['dims']) self.assertEqual(node_gview.get_operator().get_input(input_idx).dims(), node_json['inputs'][input_idx]['dims'])
self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dtype()), node_json['inputs'][input_idx]['data_type']) self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dtype()), node_json['inputs'][input_idx]['data_type'])
self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dformat()), node_json['inputs'][input_idx]['data_format']) self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dformat()), node_json['inputs'][input_idx]['data_format'])
self.assertEqual(node_gview.get_operator().nb_outputs(), len(node_json['outputs'])) self.assertEqual(node_gview.get_operator().nb_outputs(), len(node_json['outputs']))
for output_idx in range(node_gview.get_operator().nb_outputs()): for output_idx in range(node_gview.get_operator().nb_outputs()):
self.assertEqual(node_gview.get_operator().get_output(output_idx).dims(), node_json['outputs'][output_idx]['dims']) self.assertEqual(node_gview.get_operator().get_output(output_idx).dims(), node_json['outputs'][output_idx]['dims'])
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dtype()), node_json['outputs'][output_idx]['data_type']) self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dtype()), node_json['outputs'][output_idx]['data_type'])
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dformat()), node_json['outputs'][output_idx]['data_format']) self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dformat()), node_json['outputs'][output_idx]['data_format'])
self.assertEqual(len(node_gview.get_parents()), len(node_json['parents'])) self.assertEqual(len(node_gview.get_parents()), len(node_json['parents']))
self.assertEqual(len(node_gview.get_children()), len(node_json['children'])) self.assertEqual(len(node_gview.get_children()), len(node_json['children']))
if not hasattr(node_gview.get_operator(), 'get_micro_graph'): if not hasattr(node_gview.get_operator(), 'get_micro_graph'):
try: try:
self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes'])) self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes']))
self.assertDictEqual(node_gview.get_operator().attr.dict(), node_json['attributes']) self.assertDictEqual(node_gview.get_operator().attr.dict(), node_json['attributes'])
except AttributeError: except AttributeError:
self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes']) self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes'])
elif hasattr(node_gview.get_operator(), 'get_micro_graph'): elif hasattr(node_gview.get_operator(), 'get_micro_graph'):
self.assertEqual(len(node_gview.get_operator().get_micro_graph().get_nodes()), len(node_json['attributes']['micro_graph'])) self.assertEqual(len(node_gview.get_operator().get_micro_graph().get_nodes()), len(node_json['attributes']['micro_graph']))
for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes(): for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes():
for micro_node_json in node_json['attributes']['micro_graph']: for micro_node_json in node_json['attributes']['micro_graph']:
if micro_node_gview.get_operator().type() == micro_node_json['optype']: if micro_node_gview.get_operator().type() == micro_node_json['optype']:
for key, value in micro_node_gview.get_operator().attr.dict().items(): for key, value in micro_node_gview.get_operator().attr.dict().items():
if not type(value).__name__ in dir(builtins): if not type(value).__name__ in dir(builtins):
# Replace original value by its name (str) because value is of a type that could not be written to the JSON # Replace original value by its name (str) because value is of a type that could not be written to the JSON
# Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name}) # Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name})
temp_mnode_dict = micro_node_gview.get_operator().attr.dict() temp_mnode_dict = micro_node_gview.get_operator().attr.dict()
temp_mnode_dict.update({key : value.name}) temp_mnode_dict.update({key : value.name})
self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes']) self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment