Skip to content
Snippets Groups Projects
Commit 4ff3d6a8 authored by Iryna de Albuquerque Silva's avatar Iryna de Albuquerque Silva Committed by Iryna de Albuquerque Silva
Browse files

Update unit tests to handle aidge_core.Tensor objects.

parent acfb2330
No related branches found
No related tags found
2 merge requests!414Update version 0.5.1 -> 0.6.0,!411Fix attributes dictionary handling in show_graphview.py
Pipeline #71071 passed
...@@ -3,6 +3,7 @@ import tempfile ...@@ -3,6 +3,7 @@ import tempfile
import unittest import unittest
import builtins import builtins
import aidge_core import aidge_core
import numpy as np
from pathlib import Path from pathlib import Path
from aidge_core.show_graphview import gview_to_json, str_aidge_graph_structure, str_aidge_seq_scheduling from aidge_core.show_graphview import gview_to_json, str_aidge_graph_structure, str_aidge_seq_scheduling
...@@ -105,7 +106,26 @@ class test_show_gview(unittest.TestCase): ...@@ -105,7 +106,26 @@ class test_show_gview(unittest.TestCase):
if not hasattr(node_gview.get_operator(), 'get_micro_graph'): if not hasattr(node_gview.get_operator(), 'get_micro_graph'):
try: try:
self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes'])) self.assertEqual(len(node_gview.get_operator().attr.dict()), len(node_json['attributes']))
self.assertDictEqual(node_gview.get_operator().attr.dict(), node_json['attributes'])
temp_node_dict = node_gview.get_operator().attr.dict()
for key, value in node_gview.get_operator().attr.dict().items():
if isinstance(value, aidge_core.aidge_core.Tensor):
new_value = {
"dims": value.dims(),
"data_type": value.dtype(),
"tensor_data": np.array(value).tolist()
}
temp_node_dict.update({key : new_value})
elif not type(value).__name__ in dir(builtins):
temp_node_dict.update({key : str(value)})
else:
pass
self.assertDictEqual(temp_node_dict, node_json['attributes'])
except AttributeError: except AttributeError:
self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes']) self.assertIsNone(node_gview.get_operator().attr) and self.assertFalse(node_json['attributes'])
...@@ -117,14 +137,24 @@ class test_show_gview(unittest.TestCase): ...@@ -117,14 +137,24 @@ class test_show_gview(unittest.TestCase):
for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes(): for micro_node_gview in node_gview.get_operator().get_micro_graph().get_nodes():
for micro_node_json in node_json['attributes']['micro_graph']: for micro_node_json in node_json['attributes']['micro_graph']:
if micro_node_gview.get_operator().type() == micro_node_json['optype']: if micro_node_gview.get_operator().type() == micro_node_json['optype']:
temp_mnode_dict = micro_node_gview.get_operator().attr.dict() # So the dict can be updated if needed
for key, value in micro_node_gview.get_operator().attr.dict().items(): for key, value in micro_node_gview.get_operator().attr.dict().items():
if not type(value).__name__ in dir(builtins): if isinstance(value, aidge_core.aidge_core.Tensor):
# Replace original value by its name (str) because value is of a type that could not be written to the JSON new_value = {
# Cannot update this dict inplace : micro_node_gview.get_operator().attr.dict().update({key : value.name}) "dims": value.dims(),
temp_mnode_dict = micro_node_gview.get_operator().attr.dict() "data_type": str(value.dtype()),
temp_mnode_dict.update({key : value.name}) "tensor_data": np.array(value).tolist()
self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes']) }
temp_mnode_dict.update({key : new_value})
elif not type(value).__name__ in dir(builtins):
# Use str(value) to stay consistent with how json.dumps(..., default=str) handles custom objects
temp_mnode_dict.update({key : str(value)})
else:
pass
self.assertDictEqual(temp_mnode_dict, micro_node_json['attributes'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment