Skip to content
Snippets Groups Projects

Remove the consumer/producer system from the forward method.

Merged Cyril Moineau requested to merge consumerProducerRefactor into master
21 files
+ 146
82
Compare changes
  • Side-by-side
  • Inline
Files
21
@@ -36,7 +36,54 @@ class test_scheduler(unittest.TestCase):
for i in range(len(expected_out)):
self.assertEqual(expected_out[i], out_tensor[i])
def test_sequential_scheduling(self):
input_data = np.array([]).astype(np.float32)
input_tensor = aidge_core.Tensor(input_data)
input_node = aidge_core.Producer(input_tensor, "X")
graph_view = aidge_core.sequential([
aidge_core.FC(50, name='0'),
aidge_core.FC(50, name='1'),
aidge_core.FC(10, name='2'),
])
EXPECTED_SCHEDULE = ['0', '1', '2']
input_node.add_child(graph_view)
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
input_node.get_operator().set_backend("cpu")
graph_view.set_datatype(aidge_core.DataType.Float32)
graph_view.set_backend("cpu")
scheduler = aidge_core.SequentialScheduler(graph_view)
scheduler.generate_scheduling()
self.assertListEqual([i.name() for i in scheduler.get_static_scheduling()], EXPECTED_SCHEDULE)
def test_parallel_scheduling(self):
input_data = np.array([]).astype(np.float32)
input_tensor = aidge_core.Tensor(input_data)
input_node = aidge_core.Producer(input_tensor, "X")
graph_view = aidge_core.sequential([
aidge_core.FC(50, name='0'),
aidge_core.parallel([aidge_core.FC(50, name='1'), aidge_core.FC(50, name='3')]),
aidge_core.Add(name='2'),
])
EXPECTED_SCHEDULE = [['0', '1', '3', '2'], ['0', '3', '1', '2']] # Both scheduling are valid !
input_node.add_child(graph_view)
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
input_node.get_operator().set_backend("cpu")
graph_view.set_datatype(aidge_core.DataType.Float32)
graph_view.set_backend("cpu")
scheduler = aidge_core.SequentialScheduler(graph_view)
scheduler.generate_scheduling()
self.assertTrue([i.name() for i in scheduler.get_static_scheduling()] in EXPECTED_SCHEDULE)
if __name__ == '__main__':
unittest.main()
Loading