From 871215bb9d97024454b68457c36b36963c5a3d7a Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 7 Sep 2023 11:48:35 +0000 Subject: [PATCH] [Unittest] Add sequential and parallel scheduling test. --- .../unit_tests/test_scheduler.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/aidge_backend_cpu/unit_tests/test_scheduler.py b/aidge_backend_cpu/unit_tests/test_scheduler.py index bc766203..d8cf3e16 100644 --- a/aidge_backend_cpu/unit_tests/test_scheduler.py +++ b/aidge_backend_cpu/unit_tests/test_scheduler.py @@ -36,7 +36,54 @@ class test_scheduler(unittest.TestCase): for i in range(len(expected_out)): self.assertEqual(expected_out[i], out_tensor[i]) + def test_sequential_scheduling(self): + input_data = np.array([]).astype(np.float32) + input_tensor = aidge_core.Tensor(input_data) + input_node = aidge_core.Producer(input_tensor, "X") + + graph_view = aidge_core.sequential([ + aidge_core.FC(50, name='0'), + aidge_core.FC(50, name='1'), + aidge_core.FC(10, name='2'), + ]) + EXPECTED_SCHEDULE = ['0', '1', '2'] + + input_node.add_child(graph_view) + input_node.get_operator().set_datatype(aidge_core.DataType.Float32) + input_node.get_operator().set_backend("cpu") + graph_view.set_datatype(aidge_core.DataType.Float32) + graph_view.set_backend("cpu") + + scheduler = aidge_core.SequentialScheduler(graph_view) + scheduler.generate_scheduling() + + self.assertListEqual([i.name() for i in scheduler.get_static_scheduling()], EXPECTED_SCHEDULE) + + + def test_parallel_scheduling(self): + input_data = np.array([]).astype(np.float32) + input_tensor = aidge_core.Tensor(input_data) + + input_node = aidge_core.Producer(input_tensor, "X") + graph_view = aidge_core.sequential([ + aidge_core.FC(50, name='0'), + aidge_core.parallel([aidge_core.FC(50, name='1'), aidge_core.FC(50, name='3')]), + aidge_core.Add(name='2'), + ]) + + EXPECTED_SCHEDULE = [['0', '1', '3', '2'], ['0', '3', '1', '2']] # Both scheduling are valid ! + + input_node.add_child(graph_view) + input_node.get_operator().set_datatype(aidge_core.DataType.Float32) + input_node.get_operator().set_backend("cpu") + graph_view.set_datatype(aidge_core.DataType.Float32) + graph_view.set_backend("cpu") + + scheduler = aidge_core.SequentialScheduler(graph_view) + scheduler.generate_scheduling() + + self.assertTrue([i.name() for i in scheduler.get_static_scheduling()] in EXPECTED_SCHEDULE) if __name__ == '__main__': unittest.main() -- GitLab