diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index fc60f52274162155f8f891bf86c22c9a13b241f4..c7279afed2aed00981d0b15002b1676abcaef72e 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -102,5 +102,30 @@ class test_operator_binding(unittest.TestCase): genOp.get_operator().compute_output_dims() self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) + def test_set_impl(self): + + class PythonCustomImpl(aidge_core.OperatorImpl): + """Dummy implementation to test that C++ call python code + """ + def __init__(self): + aidge_core.OperatorImpl.__init__(self) # Recquired to avoid type error ! + self.idx = 0 + + def forward(self): + """Increment idx attribute on forward. + """ + self.idx += 1 + + generic_node = aidge_core.GenericOperator("Relu", 1, 1, 1, name="myReLu") + customImpl = PythonCustomImpl() + generic_op = generic_node.get_operator() + + generic_op.forward() # Do nothing, no implementation set + generic_op.set_impl(customImpl) + generic_op.forward() # Increment idx + self.assertEqual(customImpl.idx, 1) + + + if __name__ == '__main__': unittest.main()