From abbcc9994a4446030ef63f512f98e452249438e4 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 29 Feb 2024 14:27:34 +0000 Subject: [PATCH] Add unit test for python registrar system. --- aidge_core/unit_tests/test_impl.py | 72 ++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 aidge_core/unit_tests/test_impl.py diff --git a/aidge_core/unit_tests/test_impl.py b/aidge_core/unit_tests/test_impl.py new file mode 100644 index 000000000..ad7ee666e --- /dev/null +++ b/aidge_core/unit_tests/test_impl.py @@ -0,0 +1,72 @@ +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" + +import unittest +import aidge_core +from functools import reduce + +import numpy as np + +GLOBAL_CPT = 0 + +class testImpl(aidge_core.OperatorImpl): + def __init__(self, op: aidge_core.Operator): + aidge_core.OperatorImpl.__init__(self, op) # Required to avoid type error ! + + def forward(self): + global GLOBAL_CPT + GLOBAL_CPT += 1 + +class test_OperatorImpl(unittest.TestCase): + """Test Op + """ + def setUp(self): + global GLOBAL_CPT + GLOBAL_CPT = 0 + def tearDown(self): + pass + + def test_setImplementation(self): + """Test setting an implementation manually + """ + global GLOBAL_CPT + matmul = aidge_core.GenericOperator("MatMul", 1, 0, 1, name="MatMul0") + generic_matmul_op = matmul.get_operator() + generic_matmul_op.set_compute_output_dims(lambda x: x) + generic_matmul_op.set_impl(testImpl(generic_matmul_op)) + generic_matmul_op.forward() + self.assertEqual(GLOBAL_CPT, 1) + + def test_Registrar_setOp(self): + """Test registering an implementation + """ + global GLOBAL_CPT + aidge_core.register_ConvOp2D("cpu", testImpl) + self.assertTrue("cpu" in aidge_core.get_keys_ConvOp2D()) + conv = aidge_core.Conv2D(2,2,[1,1], name="Conv0") + conv.get_operator().set_backend("cpu") + conv.get_operator().forward() + self.assertEqual(GLOBAL_CPT, 1) + + def test_Registrar_setGraphView(self): + """Test registering an implementation + """ + global GLOBAL_CPT + aidge_core.register_ConvOp2D("cpu", testImpl) + aidge_core.register_ProducerOp("cpu", testImpl) + self.assertTrue("cpu" in aidge_core.get_keys_ConvOp2D()) + conv = aidge_core.Conv2D(2,2,[1,1], name="Conv0") + model = aidge_core.sequential([conv]) + model.set_backend("cpu") + conv.get_operator().forward() + self.assertEqual(GLOBAL_CPT, 1) + +if __name__ == '__main__': + unittest.main() -- GitLab