""" Testing object mode specifics. """ import numpy as np import unittest from numba import jit from numba.core import utils from numba.tests.support import TestCase def complex_constant(n): tmp = n + 4 return tmp + 3j def long_constant(n): return n + 100000000000000000000000000000000000000000000000 def delitem_usecase(x): del x[:] def loop_nest_3(x, y): n = 0 for i in range(x): for j in range(y): for k in range(x + y): n += i * j return n def array_of_object(x): return x class TestObjectMode(TestCase): def test_complex_constant(self): pyfunc = complex_constant cfunc = jit((), forceobj=True)(pyfunc) self.assertPreciseEqual(pyfunc(12), cfunc(12)) def test_long_constant(self): pyfunc = long_constant cfunc = jit((), forceobj=True)(pyfunc) self.assertPreciseEqual(pyfunc(12), cfunc(12)) def test_loop_nest(self): """ Test bug that decref the iterator early. If the bug occurs, a segfault should occur """ pyfunc = loop_nest_3 cfunc = jit((), forceobj=True)(pyfunc) self.assertEqual(pyfunc(5, 5), cfunc(5, 5)) def bm_pyfunc(): pyfunc(5, 5) def bm_cfunc(): cfunc(5, 5) utils.benchmark(bm_pyfunc) utils.benchmark(bm_cfunc) def test_array_of_object(self): cfunc = jit(forceobj=True)(array_of_object) objarr = np.array([object()] * 10) self.assertIs(cfunc(objarr), objarr) def test_sequence_contains(self): """ Test handling of the `in` comparison """ @jit(forceobj=True) def foo(x, y): return x in y self.assertTrue(foo(1, [0, 1])) self.assertTrue(foo(0, [0, 1])) self.assertFalse(foo(2, [0, 1])) with self.assertRaises(TypeError) as raises: foo(None, None) self.assertIn("is not iterable", str(raises.exception)) def test_delitem(self): pyfunc = delitem_usecase cfunc = jit((), forceobj=True)(pyfunc) l = [3, 4, 5] cfunc(l) self.assertPreciseEqual(l, []) with self.assertRaises(TypeError): cfunc(42) def test_starargs_non_tuple(self): def consumer(*x): return x @jit(forceobj=True) def foo(x): return consumer(*x) arg = "ijo" got = foo(arg) expect = foo.py_func(arg) self.assertEqual(got, tuple(arg)) self.assertEqual(got, expect) def test_expr_undef(self): @jit(forceobj=True) def foo(): # In Py3.12, this will emit a Expr.undef. return [x for x in (1, 2)] self.assertEqual(foo(), foo.py_func()) class TestObjectModeInvalidRewrite(TestCase): """ Tests to ensure that rewrite passes didn't affect objmode lowering. """ def _ensure_objmode(self, disp): self.assertTrue(disp.signatures) self.assertFalse(disp.nopython_signatures) return disp def test_static_raise_in_objmode_fallback(self): """ Test code based on user submitted issue at https://github.com/numba/numba/issues/2159 """ def test0(n): return n def test1(n): if n == 0: # static raise will fail in objmode if the IR is modified by # rewrite pass raise ValueError() return test0(n) # trigger objmode fallback compiled = jit(forceobj=True)(test1) self.assertEqual(test1(10), compiled(10)) self._ensure_objmode(compiled) def test_static_setitem_in_objmode_fallback(self): """ Test code based on user submitted issue at https://github.com/numba/numba/issues/2169 """ def test0(n): return n def test(a1, a2): a1 = np.asarray(a1) # static setitem here will fail in objmode if the IR is modified by # rewrite pass a2[0] = 1 return test0(a1.sum() + a2.sum()) # trigger objmode fallback compiled = jit(forceobj=True)(test) args = np.array([3]), np.array([4]) self.assertEqual(test(*args), compiled(*args)) self._ensure_objmode(compiled) def test_dynamic_func_objmode(self): """ Test issue https://github.com/numba/numba/issues/3355 """ func_text = "def func():\n" func_text += " np.array([1,2,3])\n" loc_vars = {} custom_globals = {'np': np} exec(func_text, custom_globals, loc_vars) func = loc_vars['func'] jitted = jit(forceobj=True)(func) jitted() if __name__ == '__main__': unittest.main()