Skip to content
Snippets Groups Projects

Remove the consumer/producer system from the forward method.

Merged Cyril Moineau requested to merge consumerProducerRefactor into master
1 file
+ 47
0
Compare changes
  • Side-by-side
  • Inline
@@ -36,7 +36,54 @@ class test_scheduler(unittest.TestCase):
@@ -36,7 +36,54 @@ class test_scheduler(unittest.TestCase):
for i in range(len(expected_out)):
for i in range(len(expected_out)):
self.assertEqual(expected_out[i], out_tensor[i])
self.assertEqual(expected_out[i], out_tensor[i])
 
def test_sequential_scheduling(self):
 
input_data = np.array([]).astype(np.float32)
 
input_tensor = aidge_core.Tensor(input_data)
 
input_node = aidge_core.Producer(input_tensor, "X")
 
 
graph_view = aidge_core.sequential([
 
aidge_core.FC(50, name='0'),
 
aidge_core.FC(50, name='1'),
 
aidge_core.FC(10, name='2'),
 
])
 
EXPECTED_SCHEDULE = ['0', '1', '2']
 
 
input_node.add_child(graph_view)
 
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
 
input_node.get_operator().set_backend("cpu")
 
graph_view.set_datatype(aidge_core.DataType.Float32)
 
graph_view.set_backend("cpu")
 
 
scheduler = aidge_core.SequentialScheduler(graph_view)
 
scheduler.generate_scheduling()
 
 
self.assertListEqual([i.name() for i in scheduler.get_static_scheduling()], EXPECTED_SCHEDULE)
 
 
 
def test_parallel_scheduling(self):
 
input_data = np.array([]).astype(np.float32)
 
input_tensor = aidge_core.Tensor(input_data)
 
 
input_node = aidge_core.Producer(input_tensor, "X")
 
graph_view = aidge_core.sequential([
 
aidge_core.FC(50, name='0'),
 
aidge_core.parallel([aidge_core.FC(50, name='1'), aidge_core.FC(50, name='3')]),
 
aidge_core.Add(name='2'),
 
])
 
 
EXPECTED_SCHEDULE = [['0', '1', '3', '2'], ['0', '3', '1', '2']] # Both scheduling are valid !
 
 
input_node.add_child(graph_view)
 
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
 
input_node.get_operator().set_backend("cpu")
 
graph_view.set_datatype(aidge_core.DataType.Float32)
 
graph_view.set_backend("cpu")
 
 
scheduler = aidge_core.SequentialScheduler(graph_view)
 
scheduler.generate_scheduling()
 
 
self.assertTrue([i.name() for i in scheduler.get_static_scheduling()] in EXPECTED_SCHEDULE)
if __name__ == '__main__':
if __name__ == '__main__':
unittest.main()
unittest.main()
Loading