ai-content-maker/.venv/Lib/site-packages/numba/tests/test_object_mode.py

192 lines
4.7 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
"""
Testing object mode specifics.
"""
import numpy as np
import unittest
from numba import jit
from numba.core import utils
from numba.tests.support import TestCase
def complex_constant(n):
    """Return ``n + 4 + 3j``.

    Use case for the object-mode tests: the function body contains a
    complex literal (``3j``) that has to be carried through compilation.
    """
    tmp = n + 4
    return tmp + 3j
def long_constant(n):
    """Return *n* plus an integer literal too large for a 64-bit word.

    Use case for the object-mode tests: the constant only fits in an
    arbitrary-precision Python ``int``.
    """
    return n + 100000000000000000000000000000000000000000000000
def delitem_usecase(x):
    """Delete the full slice of *x* in place (``del x[:]``).

    Returns ``None``.  Objects that do not support item deletion raise
    whatever the ``del`` statement raises for them (``TypeError`` for an
    ``int``, for instance).
    """
    del x[:]
def loop_nest_3(x, y):
    """Accumulate ``i * j`` over a three-level nest of ``range`` loops.

    The innermost loop variable is deliberately unused: the loop only
    repeats the accumulation ``x + y`` times, and it keeps three iterators
    alive at once, which is what the calling test relies on.

    Returns the accumulated total (``0`` when any range is empty).
    """
    n = 0
    for i in range(x):
        for j in range(y):
            for k in range(x + y):
                n += i * j
    return n
def array_of_object(x):
    """Return *x* itself, unchanged (identity function).

    Used by the tests to check that an argument passed through a compiled
    function comes back as the very same object.
    """
    return x
class TestObjectMode(TestCase):
    """Tests for functions compiled with ``jit(..., forceobj=True)``."""

    def test_complex_constant(self):
        """Compiled and pure-Python results agree for a complex literal."""
        pyfunc = complex_constant
        cfunc = jit((), forceobj=True)(pyfunc)
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_long_constant(self):
        """Compiled and pure-Python results agree for a very large int."""
        pyfunc = long_constant
        cfunc = jit((), forceobj=True)(pyfunc)
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_loop_nest(self):
        """
        Regression test for a bug that decref'ed the iterator too early.
        If the bug is present, a segfault should occur.
        """
        pyfunc = loop_nest_3
        cfunc = jit((), forceobj=True)(pyfunc)
        self.assertEqual(pyfunc(5, 5), cfunc(5, 5))

        def bm_pyfunc():
            pyfunc(5, 5)

        def bm_cfunc():
            cfunc(5, 5)

        # Hand both variants to the benchmark helper.
        utils.benchmark(bm_pyfunc)
        utils.benchmark(bm_cfunc)

    def test_array_of_object(self):
        """An object-dtype array passed through comes back as the same object."""
        cfunc = jit(forceobj=True)(array_of_object)
        objarr = np.array([object()] * 10)
        self.assertIs(cfunc(objarr), objarr)

    def test_sequence_contains(self):
        """
        Test handling of the `in` comparison
        """
        @jit(forceobj=True)
        def foo(x, y):
            return x in y

        self.assertTrue(foo(1, [0, 1]))
        self.assertTrue(foo(0, [0, 1]))
        self.assertFalse(foo(2, [0, 1]))

        with self.assertRaises(TypeError) as raises:
            foo(None, None)

        # Must run after the ``with`` block: statements following the
        # raising call inside the block would never execute.
        self.assertIn("is not iterable", str(raises.exception))

    def test_delitem(self):
        """``del x[:]`` empties a list and raises TypeError for an int."""
        pyfunc = delitem_usecase
        cfunc = jit((), forceobj=True)(pyfunc)

        lst = [3, 4, 5]
        cfunc(lst)
        self.assertPreciseEqual(lst, [])

        with self.assertRaises(TypeError):
            cfunc(42)

    def test_starargs_non_tuple(self):
        """Star-unpacking a non-tuple iterable (a str) into a call works."""
        def consumer(*x):
            return x

        @jit(forceobj=True)
        def foo(x):
            return consumer(*x)

        arg = "ijo"
        got = foo(arg)
        expect = foo.py_func(arg)
        self.assertEqual(got, tuple(arg))
        self.assertEqual(got, expect)

    def test_expr_undef(self):
        """A list comprehension gives the same result compiled and uncompiled."""
        @jit(forceobj=True)
        def foo():
            # In Py3.12, this will emit an Expr.undef.
            return [x for x in (1, 2)]

        self.assertEqual(foo(), foo.py_func())
class TestObjectModeInvalidRewrite(TestCase):
    """
    Tests to ensure that rewrite passes didn't affect objmode lowering.
    """

    def _ensure_objmode(self, disp):
        """Assert *disp* has signatures but no nopython ones; return *disp*."""
        self.assertTrue(disp.signatures)
        self.assertFalse(disp.nopython_signatures)
        return disp

    def test_static_raise_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2159
        """
        def test0(n):
            return n

        def test1(n):
            if n == 0:
                # static raise will fail in objmode if the IR is modified by
                # rewrite pass
                raise ValueError()
            return test0(n)  # trigger objmode fallback

        compiled = jit(forceobj=True)(test1)
        self.assertEqual(test1(10), compiled(10))
        self._ensure_objmode(compiled)

    def test_static_setitem_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2169
        """
        def test0(n):
            return n

        def test(a1, a2):
            a1 = np.asarray(a1)
            # static setitem here will fail in objmode if the IR is modified by
            # rewrite pass
            a2[0] = 1
            return test0(a1.sum() + a2.sum())  # trigger objmode fallback

        compiled = jit(forceobj=True)(test)
        args = np.array([3]), np.array([4])
        self.assertEqual(test(*args), compiled(*args))
        self._ensure_objmode(compiled)

    def test_dynamic_func_objmode(self):
        """
        Test issue https://github.com/numba/numba/issues/3355
        """
        # Build the function from source text so that it is created
        # dynamically via exec() rather than defined in this file.
        func_text = "def func():\n"
        func_text += " np.array([1,2,3])\n"
        loc_vars = {}
        custom_globals = {'np': np}
        exec(func_text, custom_globals, loc_vars)
        func = loc_vars['func']
        jitted = jit(forceobj=True)(func)
        jitted()
# Run the test suite when this module is executed as a script.
if __name__ == '__main__':
    unittest.main()