From abbcc9994a4446030ef63f512f98e452249438e4 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 29 Feb 2024 14:27:34 +0000
Subject: [PATCH] Add unit test for python registrar system.

---
 aidge_core/unit_tests/test_impl.py | 72 ++++++++++++++++++++++++++++++
 1 file changed, 72 insertions(+)
 create mode 100644 aidge_core/unit_tests/test_impl.py

diff --git a/aidge_core/unit_tests/test_impl.py b/aidge_core/unit_tests/test_impl.py
new file mode 100644
index 000000000..ad7ee666e
--- /dev/null
+++ b/aidge_core/unit_tests/test_impl.py
@@ -0,0 +1,72 @@
+"""
+Copyright (c) 2023 CEA-List
+
+This program and the accompanying materials are made available under the
+terms of the Eclipse Public License 2.0 which is available at
+http://www.eclipse.org/legal/epl-2.0.
+
+SPDX-License-Identifier: EPL-2.0
+"""
+
+import unittest
+import aidge_core
+from functools import reduce
+
+import numpy as np
+
+GLOBAL_CPT = 0
+
+class testImpl(aidge_core.OperatorImpl):
+    def __init__(self, op: aidge_core.Operator):
+        aidge_core.OperatorImpl.__init__(self, op) # Required to avoid type error !
+
+    def forward(self):
+        global GLOBAL_CPT
+        GLOBAL_CPT += 1
+
+class test_OperatorImpl(unittest.TestCase):
+    """Test Op
+    """
+    def setUp(self):
+        global GLOBAL_CPT
+        GLOBAL_CPT = 0
+    def tearDown(self):
+        pass
+
+    def test_setImplementation(self):
+        """Test setting an implementation manually
+        """
+        global GLOBAL_CPT
+        matmul = aidge_core.GenericOperator("MatMul", 1, 0, 1, name="MatMul0")
+        generic_matmul_op = matmul.get_operator()
+        generic_matmul_op.set_compute_output_dims(lambda x: x)
+        generic_matmul_op.set_impl(testImpl(generic_matmul_op))
+        generic_matmul_op.forward()
+        self.assertEqual(GLOBAL_CPT, 1)
+
+    def test_Registrar_setOp(self):
+        """Test registering an implementation
+        """
+        global GLOBAL_CPT
+        aidge_core.register_ConvOp2D("cpu", testImpl)
+        self.assertTrue("cpu" in aidge_core.get_keys_ConvOp2D())
+        conv = aidge_core.Conv2D(2,2,[1,1], name="Conv0")
+        conv.get_operator().set_backend("cpu")
+        conv.get_operator().forward()
+        self.assertEqual(GLOBAL_CPT, 1)
+
+    def test_Registrar_setGraphView(self):
+        """Test registering an implementation
+        """
+        global GLOBAL_CPT
+        aidge_core.register_ConvOp2D("cpu", testImpl)
+        aidge_core.register_ProducerOp("cpu", testImpl)
+        self.assertTrue("cpu" in aidge_core.get_keys_ConvOp2D())
+        conv = aidge_core.Conv2D(2,2,[1,1], name="Conv0")
+        model = aidge_core.sequential([conv])
+        model.set_backend("cpu")
+        conv.get_operator().forward()
+        self.assertEqual(GLOBAL_CPT, 1)
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab