ai-content-maker/.venv/Lib/site-packages/numba/tests/test_object_mode.py

192 lines
4.7 KiB
Python

"""
Testing object mode specifics.
"""
import numpy as np
import unittest
from numba import jit
from numba.core import utils
from numba.tests.support import TestCase
def complex_constant(n):
    """Return ``n + 4`` with the imaginary constant ``3j`` added.

    Used by ``test_complex_constant`` to check that a complex literal is
    handled when the function is compiled with ``forceobj=True``.
    """
    shifted = n + 4
    return shifted + 3j
def long_constant(n):
    """Return ``n`` plus an integer literal far wider than 64 bits.

    Used by ``test_long_constant`` to check that an arbitrary-precision
    integer constant survives object-mode compilation.
    """
    big = 100000000000000000000000000000000000000000000000
    return n + big
def delitem_usecase(x):
    """Delete the full slice of ``x`` in place (this empties a list).

    Objects that do not support item deletion, such as ``int``, raise
    ``TypeError``.
    """
    everything = slice(None)
    del x[everything]
def loop_nest_3(x, y):
    """Sum ``i * j`` for ``i < x`` and ``j < y`` inside a triple loop nest.

    The innermost loop over ``range(x + y)`` does not enter the product; it
    only repeats each ``i * j`` term ``x + y`` times. The explicit nested
    ``for`` loops are intentional: ``test_loop_nest`` uses this function to
    exercise iterator handling.
    """
    total = 0
    for outer in range(x):
        for middle in range(y):
            for _ in range(x + y):
                total += outer * middle
    return total
def array_of_object(x):
    """Identity function: return the very same object that was passed in."""
    result = x
    return result
class TestObjectMode(TestCase):
    """Checks on functions compiled with ``jit(..., forceobj=True)``."""

    @staticmethod
    def _compile_with_empty_sig(pyfunc):
        # Compile pyfunc in forced object mode, passing ``()`` as the
        # signature argument to ``jit``.
        return jit((), forceobj=True)(pyfunc)

    def test_complex_constant(self):
        compiled = self._compile_with_empty_sig(complex_constant)
        self.assertPreciseEqual(complex_constant(12), compiled(12))

    def test_long_constant(self):
        compiled = self._compile_with_empty_sig(long_constant)
        self.assertPreciseEqual(long_constant(12), compiled(12))

    def test_loop_nest(self):
        """
        Regression test: nested loops must not decref an iterator too early.
        When that bug is present the process is expected to segfault rather
        than fail an assertion.
        """
        compiled = self._compile_with_empty_sig(loop_nest_3)
        self.assertEqual(loop_nest_3(5, 5), compiled(5, 5))
        # utils.benchmark is used here only to invoke each version
        # repeatedly; its timing result is discarded.
        utils.benchmark(lambda: loop_nest_3(5, 5))
        utils.benchmark(lambda: compiled(5, 5))

    def test_array_of_object(self):
        compiled = jit(forceobj=True)(array_of_object)
        payload = np.array([object()] * 10)
        self.assertIs(compiled(payload), payload)

    def test_sequence_contains(self):
        """
        Check the ``in`` comparison under object mode.
        """
        @jit(forceobj=True)
        def contains(x, y):
            return x in y

        for needle, expected in ((1, True), (0, True), (2, False)):
            self.assertEqual(bool(contains(needle, [0, 1])), expected)

        with self.assertRaises(TypeError) as ctx:
            contains(None, None)
        self.assertIn("is not iterable", str(ctx.exception))

    def test_delitem(self):
        compiled = self._compile_with_empty_sig(delitem_usecase)
        items = [3, 4, 5]
        compiled(items)
        self.assertPreciseEqual(items, [])
        with self.assertRaises(TypeError):
            compiled(42)

    def test_starargs_non_tuple(self):
        def collect(*parts):
            return parts

        @jit(forceobj=True)
        def spread(seq):
            return collect(*seq)

        text = "ijo"
        actual = spread(text)
        reference = spread.py_func(text)
        self.assertEqual(actual, tuple(text))
        self.assertEqual(actual, reference)

    def test_expr_undef(self):
        @jit(forceobj=True)
        def build():
            # Under Python 3.12 this comprehension produces an Expr.undef.
            return [x for x in (1, 2)]

        self.assertEqual(build(), build.py_func())
class TestObjectModeInvalidRewrite(TestCase):
    """
    Tests to ensure that rewrite passes didn't affect objmode lowering.
    """

    def _ensure_objmode(self, disp):
        """Assert that dispatcher ``disp`` holds only object-mode compilations.

        Checks that ``disp.signatures`` is non-empty and that
        ``disp.nopython_signatures`` is empty, then returns ``disp``.
        """
        self.assertTrue(disp.signatures)
        self.assertFalse(disp.nopython_signatures)
        return disp

    def test_static_raise_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2159
        """
        def test0(n):
            return n

        def test1(n):
            if n == 0:
                # static raise will fail in objmode if the IR is modified by
                # rewrite pass
                raise ValueError()
            return test0(n)  # trigger objmode fallback

        compiled = jit(forceobj=True)(test1)
        self.assertEqual(test1(10), compiled(10))
        # Also execute the raising branch, so the lowered ``raise`` is run
        # and not merely compiled.
        with self.assertRaises(ValueError):
            compiled(0)
        self._ensure_objmode(compiled)

    def test_static_setitem_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2169
        """
        def test0(n):
            return n

        def test(a1, a2):
            a1 = np.asarray(a1)
            # static setitem here will fail in objmode if the IR is modified by
            # rewrite pass
            a2[0] = 1
            return test0(a1.sum() + a2.sum())  # trigger objmode fallback

        compiled = jit(forceobj=True)(test)
        args = np.array([3]), np.array([4])
        self.assertEqual(test(*args), compiled(*args))
        self._ensure_objmode(compiled)

    def test_dynamic_func_objmode(self):
        """
        Test issue https://github.com/numba/numba/issues/3355
        """
        func_text = "def func():\n"
        func_text += " np.array([1,2,3])\n"
        loc_vars = {}
        custom_globals = {'np': np}
        # func_text is a fixed literal built above, not external input.
        exec(func_text, custom_globals, loc_vars)
        func = loc_vars['func']
        jitted = jit(forceobj=True)(func)
        # func has no return statement, so the compiled call yields None.
        self.assertIsNone(jitted())
        # Like the sibling tests, confirm the compilation was object mode.
        self._ensure_objmode(jitted)
if __name__ == '__main__':
    # Allow running this module directly as a script.
    unittest.main()