Skip to content
Snippets Groups Projects
Commit 19398e4a authored by Iryna de Albuquerque Silva's avatar Iryna de Albuquerque Silva
Browse files

Replaced local function for topological ordering by new GraphView's get_ordered_nodes() method

parent c6654d1e
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!211Add show_graphview funcionality.
......@@ -4,73 +4,11 @@ import builtins
import aidge_core
import numpy as np
def dfs(graph : aidge_core.GraphView, node : aidge_core.Node, visited_nodes : list[aidge_core.Node], sorted_nodes : list[aidge_core.Node]) -> None:
"""
Performs the depth-first search algorithm for topological sorting.
:param graph: An unsorted GraphView of Aidge.
:type graph: aidge_core.GraphView
:param node: The GraphView's Node that is being treated.
:type node: aidge_core.Node
:param visited_nodes: List of nodes that have already been visited.
:type visited_nodes: list[aidge_core.Node]
:param sorted_nodes: List of nodes that have already been sorted.
:type sorted_nodes: list[aidge_core.Node]
"""
node_children = node.get_children()
visited_nodes.add(node)
for child in node_children:
if child not in visited_nodes:
dfs(graph, child, visited_nodes, sorted_nodes)
sorted_nodes.append(node)
# Make sure Producers are treated:
parents = []
for parent in node.get_parents():
try:
has_parents = parent.get_parents()
except AttributeError:
has_parents = False
if (not has_parents) and (parent not in sorted_nodes):
parents.append(parent)
parents.reverse()
sorted_nodes.extend(parents)
return None
def topological_sort(graph : aidge_core.GraphView) -> list[aidge_core.Node]:
"""
Performs topological sorting by applying depth-first search algorithm recursively.
:param graph: An unsorted GraphView of Aidge.
:type graph: aidge_core.GraphView
:return: A list with the GraphView's sorted nodes.
:rtype: list[aidge_core.Node]
"""
input_nodes = graph.get_input_nodes()
visited_nodes = set()
sorted_nodes = []
for input in input_nodes:
if input not in visited_nodes:
dfs(input_nodes, input, visited_nodes, sorted_nodes)
sorted_nodes.reverse()
return sorted_nodes
def retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]:
def _retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, bool, None]:
"""
Returns the dictionary containing the attributes of a given Node.
:param graph: A Node in the list of sorted nodes.
:param graph: A Node in the list of ordered nodes.
:type graph: aidge_core.Node
:return: A dictionary with the Node's attributes.
:rtype: dict[str, int, float, bool, None]
......@@ -86,11 +24,11 @@ def retrieve_operator_attrs(node : aidge_core.Node) -> dict[str, int, float, boo
return node_attr_dict
def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext : bool, write_trainable_params_embed : bool, params_file_format : str, path_trainable_params : str) -> dict[str, int, float, bool, None]:
def _create_dict(ordered_nodes : list[aidge_core.Node], write_trainable_params_ext : bool, write_trainable_params_embed : bool, params_file_format : str, path_trainable_params : str) -> dict[str, int, float, bool, None]:
"""
Creates a dictionary to store the information of a given sorted GraphView.
Creates a dictionary to store the information of a given ordered GraphView.
:param sorted_nodes: A list with the GraphView's sorted nodes.
:param ordered_nodes: A list with the GraphView's ordered nodes.
:type graph: list
:param write_trainable_params_ext: Whether or not to write the eventual trainable parameters of the Nodes in an external file.
:type write_trainable_params_ext: bool
......@@ -107,7 +45,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext
graphview_dict = {'graph': []}
for node in sorted_nodes:
for node in ordered_nodes:
if node is not None:
node_dict = {'name' : node.name(),
......@@ -171,12 +109,12 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext
micro_node_dict = {'name' : micro_node.name(),
'optype' : micro_node.type()}
micro_node_attr_dict = retrieve_operator_attrs(micro_node)
micro_node_attr_dict = _retrieve_operator_attrs(micro_node)
micro_node_dict['attributes'] = micro_node_attr_dict
attributes_dict['micro_graph'].append(micro_node_dict)
else:
node_attr_dict = retrieve_operator_attrs(node)
node_attr_dict = _retrieve_operator_attrs(node)
attributes_dict.update(node_attr_dict)
node_dict['attributes'] = attributes_dict
......@@ -215,7 +153,7 @@ def create_dict(sorted_nodes : list[aidge_core.Node], write_trainable_params_ext
return graphview_dict
def write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None:
def _write_dict_json(graphview_dict : dict[str, int, float, bool, None], json_path : str) -> None:
"""
Writes dictionary containing GraphView description to a JSON file.
......@@ -255,13 +193,13 @@ def gview_to_json(gview : aidge_core.GraphView, json_path : str, write_trainable
else:
path_trainable_params = ''
# Sort graphview
sorted_nodes = topological_sort(gview)
# Sort GraphView in topological order
ordered_nodes = gview.get_ordered_nodes()
# Create dict from graphview
graphview_dict = create_dict(sorted_nodes, write_trainable_params_ext, write_trainable_params_embed, params_file_format, path_trainable_params)
# Create dict from GraphView
graphview_dict = _create_dict(ordered_nodes, write_trainable_params_ext, write_trainable_params_embed, params_file_format, path_trainable_params)
# Write dict to json
write_dict_json(graphview_dict, json_path)
# Write dict to JSON
_write_dict_json(graphview_dict, json_path)
return None
\ No newline at end of file
......@@ -50,7 +50,7 @@ def create_gview():
return gview
class test_show_gview(unittest.TestCase):
"""Test aidge show GraphView
"""Test aidge functionality to show GraphView.
"""
def setUp(self):
......@@ -59,9 +59,8 @@ class test_show_gview(unittest.TestCase):
def tearDown(self):
pass
def test_model_to_json(self):
def test_gview_to_json(self):
gview = create_gview()
# Create temporary file to store JSON model description
......@@ -74,58 +73,55 @@ class test_show_gview(unittest.TestCase):
model_json = json.load(fp)
# Get list of nodes of Aidge graphview
gview_ranked_nodes = gview.get_ranked_nodes()[0]
# Iterate over ranked_nodes
gview_ordered_nodes = gview.get_ordered_nodes()
self.assertEqual(len(gview_ranked_nodes), len(model_json['graph']))
# Iterate over the list of ordered nodes and the corresponding JSON
self.assertEqual(len(gview_ordered_nodes), len(model_json['graph']))
for node_gview in gview_ranked_nodes:
for node_json in model_json['graph']:
if node_gview.name() == node_json['name']:
for node_gview, node_json in zip(gview_ordered_nodes, model_json['graph']):
self.assertEqual(node_gview.get_operator().type(), node_json['optype'])
self.assertEqual(node_gview.get_operator().nb_inputs(), node_json['nb_inputs'])
self.assertEqual(node_gview.get_operator().nb_outputs(), node_json['nb_outputs'])
self.assertEqual(node_gview.get_operator().nb_inputs(), len(node_json['inputs']))
for input_idx in range(node_gview.get_operator().nb_inputs()):
self.assertEqual(node_gview.get_operator().get_input(input_idx).dims(), node_json['inputs'][input_idx]['dims'])
self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dtype()), node_json['inputs'][input_idx]['data_type'])
self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dformat()), node_json['inputs'][input_idx]['data_format'])
self.assertEqual(node_gview.get_operator().nb_outputs(), len(node_json['outputs']))
for output_idx in range(node_gview.get_operator().nb_outputs()):
self.assertEqual(node_gview.get_operator().get_output(output_idx).dims(), node_json['outputs'][output_idx]['dims'])
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dtype()), node_json['outputs'][output_idx]['data_type'])
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dformat()), node_json['outputs'][output_idx]['data_format'])
self.assertEqual(len(node_gview.get_parents()), len(node_json['parents']))
self.assertEqual(len(node_gview.get_children()), len(node_json['children']))
if not hasattr(node_gview.get_operator(), 'get_micro_graph'):
try:
self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes']))
self.assertDictEqual(node_gview.get_operator().attr.dict(), node_json['attributes'])
except AttributeError:
self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes'])
elif hasattr(node_gview.get_operator(), 'get_micro_graph'):
self.assertEqual(len(node_gview.get_operator().get_micro_graph().get_nodes()), len(node_json['attributes']['micro_graph']))
for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes():
for micro_node_json in node_json['attributes']['micro_graph']:
if micro_node_gview.get_operator().type() == micro_node_json['optype']:
for key, value in micro_node_gview.get_operator().attr.dict().items():
if not type(value).__name__ in dir(builtins):
# Replace original value by its name (str) because value is of a type that could not be written to the JSON
# Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name})
temp_mnode_dict = micro_node_gview.get_operator().attr.dict()
temp_mnode_dict.update({key : value.name})
self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes'])
self.assertEqual(node_gview.get_operator().type(), node_json['optype'])
self.assertEqual(node_gview.get_operator().nb_inputs(), node_json['nb_inputs'])
self.assertEqual(node_gview.get_operator().nb_outputs(), node_json['nb_outputs'])
self.assertEqual(node_gview.get_operator().nb_inputs(), len(node_json['inputs']))
for input_idx in range(node_gview.get_operator().nb_inputs()):
self.assertEqual(node_gview.get_operator().get_input(input_idx).dims(), node_json['inputs'][input_idx]['dims'])
self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dtype()), node_json['inputs'][input_idx]['data_type'])
self.assertEqual(str(node_gview.get_operator().get_input(input_idx).dformat()), node_json['inputs'][input_idx]['data_format'])
self.assertEqual(node_gview.get_operator().nb_outputs(), len(node_json['outputs']))
for output_idx in range(node_gview.get_operator().nb_outputs()):
self.assertEqual(node_gview.get_operator().get_output(output_idx).dims(), node_json['outputs'][output_idx]['dims'])
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dtype()), node_json['outputs'][output_idx]['data_type'])
self.assertEqual(str(node_gview.get_operator().get_output(output_idx).dformat()), node_json['outputs'][output_idx]['data_format'])
self.assertEqual(len(node_gview.get_parents()), len(node_json['parents']))
self.assertEqual(len(node_gview.get_children()), len(node_json['children']))
if not hasattr(node_gview.get_operator(), 'get_micro_graph'):
try:
self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes']))
self.assertDictEqual(node_gview.get_operator().attr.dict(), node_json['attributes'])
except AttributeError:
self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes'])
elif hasattr(node_gview.get_operator(), 'get_micro_graph'):
self.assertEqual(len(node_gview.get_operator().get_micro_graph().get_nodes()), len(node_json['attributes']['micro_graph']))
for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes():
for micro_node_json in node_json['attributes']['micro_graph']:
if micro_node_gview.get_operator().type() == micro_node_json['optype']:
for key, value in micro_node_gview.get_operator().attr.dict().items():
if not type(value).__name__ in dir(builtins):
# Replace original value by its name (str) because value is of a type that could not be written to the JSON
# Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name})
temp_mnode_dict = micro_node_gview.get_operator().attr.dict()
temp_mnode_dict.update({key : value.name})
self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes'])
if __name__ == '__main__':
unittest.main()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment